diff --git a/internal/querycoord/channel_allocator.go b/internal/querycoord/channel_allocator.go
new file mode 100644
index 0000000000..00d7b5bd37
--- /dev/null
+++ b/internal/querycoord/channel_allocator.go
@@ -0,0 +1,79 @@
+// Copyright (C) 2019-2020 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License.
+
+package querycoord
+
+import (
+    "context"
+    "errors"
+    "sort"
+    "time"
+
+    "github.com/milvus-io/milvus/internal/log"
+    "github.com/milvus-io/milvus/internal/proto/querypb"
+)
+
+func defaultChannelAllocatePolicy() ChannelAllocatePolicy {
+    return shuffleChannelsToQueryNode
+}
+
+// ChannelAllocatePolicy is the signature of the helper functions that allocate dmChannels to queryNodes
+type ChannelAllocatePolicy func(ctx context.Context, reqs []*querypb.WatchDmChannelsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64) error
+
+// shuffleChannelsToQueryNode assigns each request to the online node that currently watches the fewest dmChannels
+func shuffleChannelsToQueryNode(ctx context.Context, reqs []*querypb.WatchDmChannelsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64) error {
+    for {
+        availableNodes, err := cluster.onlineNodes()
+        if err != nil {
+            log.Debug(err.Error())
+            if !wait {
+                return err
+            }
+            time.Sleep(1 * time.Second)
+            continue
+        }
+        for _, id := range excludeNodeIDs {
+            delete(availableNodes, id)
+        }
+
+        nodeID2NumChannels := make(map[int64]int)
+        for nodeID := range availableNodes {
+            numChannels, err := cluster.getNumDmChannels(nodeID)
+            if err != nil {
+                delete(availableNodes, nodeID)
+                continue
+            }
+            nodeID2NumChannels[nodeID] = numChannels
+        }
+
+        if len(availableNodes) > 0 {
+            nodeIDSlice := make([]int64, 0)
+            for nodeID := range availableNodes {
+                nodeIDSlice = append(nodeIDSlice, nodeID)
+            }
+
+            for _, req := range reqs {
+                sort.Slice(nodeIDSlice, func(i, j int) bool {
+                    return nodeID2NumChannels[nodeIDSlice[i]] < nodeID2NumChannels[nodeIDSlice[j]]
+                })
+                req.NodeID = nodeIDSlice[0]
+                nodeID2NumChannels[nodeIDSlice[0]]++
+            }
+            return nil
+        }
+
+        if !wait {
+            return errors.New("no queryNode to allocate")
+        }
+        // every node is offline or excluded; back off before rechecking
+        time.Sleep(1 * time.Second)
+    }
+}
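The allocator is exposed as a plain function type, so swapping in a different placement strategy only means supplying another function with the ChannelAllocatePolicy signature and wiring it in where defaultChannelAllocatePolicy() is used. A minimal sketch of such an alternative (hypothetical, not part of this patch: roundRobinChannelPolicy is an invented name, and it assumes the same Cluster interface and querypb types used above):

func roundRobinChannelPolicy(ctx context.Context, reqs []*querypb.WatchDmChannelsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64) error {
    availableNodes, err := cluster.onlineNodes()
    if err != nil {
        return err
    }
    for _, id := range excludeNodeIDs {
        delete(availableNodes, id)
    }
    if len(availableNodes) == 0 {
        return errors.New("no queryNode to allocate")
    }
    nodeIDs := make([]int64, 0, len(availableNodes))
    for nodeID := range availableNodes {
        nodeIDs = append(nodeIDs, nodeID)
    }
    // sort for a deterministic rotation order across calls
    sort.Slice(nodeIDs, func(i, j int) bool { return nodeIDs[i] < nodeIDs[j] })
    // hand requests out in a fixed rotation instead of by channel count
    for i, req := range reqs {
        req.NodeID = nodeIDs[i%len(nodeIDs)]
    }
    return nil
}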
diff --git a/internal/querycoord/channel_allocator_test.go b/internal/querycoord/channel_allocator_test.go
new file mode 100644
index 0000000000..344b99dc34
--- /dev/null
+++ b/internal/querycoord/channel_allocator_test.go
@@ -0,0 +1,84 @@
+// Copyright (C) 2019-2020 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License.
+
+package querycoord
+
+import (
+    "context"
+    "testing"
+
+    "github.com/stretchr/testify/assert"
+
+    etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
+    "github.com/milvus-io/milvus/internal/proto/datapb"
+    "github.com/milvus-io/milvus/internal/proto/querypb"
+    "github.com/milvus-io/milvus/internal/util/sessionutil"
+    "github.com/milvus-io/milvus/internal/util/typeutil"
+)
+
+func TestShuffleChannelsToQueryNode(t *testing.T) {
+    refreshParams()
+    baseCtx, cancel := context.WithCancel(context.Background())
+    kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
+    assert.Nil(t, err)
+    clusterSession := sessionutil.NewSession(context.Background(), Params.MetaRootPath, Params.EtcdEndpoints)
+    clusterSession.Init(typeutil.QueryCoordRole, Params.Address, true)
+    meta, err := newMeta(baseCtx, kv, nil, nil)
+    assert.Nil(t, err)
+    cluster := &queryNodeCluster{
+        ctx:         baseCtx,
+        cancel:      cancel,
+        client:      kv,
+        clusterMeta: meta,
+        nodes:       make(map[int64]Node),
+        newNodeFn:   newQueryNodeTest,
+        session:     clusterSession,
+    }
+
+    firstReq := &querypb.WatchDmChannelsRequest{
+        CollectionID: defaultCollectionID,
+        PartitionID:  defaultPartitionID,
+        Infos: []*datapb.VchannelInfo{
+            {
+                ChannelName: "test1",
+            },
+        },
+    }
+    secondReq := &querypb.WatchDmChannelsRequest{
+        CollectionID: defaultCollectionID,
+        PartitionID:  defaultPartitionID,
+        Infos: []*datapb.VchannelInfo{
+            {
+                ChannelName: "test2",
+            },
+        },
+    }
+    reqs := []*querypb.WatchDmChannelsRequest{firstReq, secondReq}
+
+    err = shuffleChannelsToQueryNode(baseCtx, reqs, cluster, false, nil)
+    assert.NotNil(t, err)
+
+    node, err := startQueryNodeServer(baseCtx)
+    assert.Nil(t, err)
+    nodeSession := node.session
+    nodeID := node.queryNodeID
+    cluster.registerNode(baseCtx, nodeSession, nodeID, disConnect)
+    waitQueryNodeOnline(cluster, nodeID)
+
+    err = shuffleChannelsToQueryNode(baseCtx, reqs, cluster, false, nil)
+    assert.Nil(t, err)
+
+    assert.Equal(t, nodeID, firstReq.NodeID)
+    assert.Equal(t, nodeID, secondReq.NodeID)
+
+    err = removeAllSession()
+    assert.Nil(t, err)
+}
diff --git a/internal/querycoord/cluster.go b/internal/querycoord/cluster.go
index 192d7de63f..00f7d374ba 100644
--- a/internal/querycoord/cluster.go
+++ b/internal/querycoord/cluster.go
@@ -69,6 +69,9 @@ type Cluster interface {
     offlineNodes() (map[int64]Node, error)
     hasNode(nodeID int64) bool
 
+    allocateSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, wait bool, excludeNodeIDs []int64) error
+    allocateChannelsToQueryNode(ctx context.Context, reqs []*querypb.WatchDmChannelsRequest, wait bool, excludeNodeIDs []int64) error
+
     getSessionVersion() int64
 
     getMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest) []queryNodeGetMetricsResponse
@@ -93,22 +96,26 @@ type queryNodeCluster struct {
     sessionVersion int64
 
     sync.RWMutex
-    clusterMeta Meta
-    nodes       map[int64]Node
-    newNodeFn   newQueryNodeFn
+    clusterMeta      Meta
+    nodes            map[int64]Node
+    newNodeFn        newQueryNodeFn
+    segmentAllocator SegmentAllocatePolicy
+    channelAllocator ChannelAllocatePolicy
 }
 
 func newQueryNodeCluster(ctx context.Context, clusterMeta Meta, kv *etcdkv.EtcdKV, newNodeFn newQueryNodeFn, session *sessionutil.Session) (Cluster, error) {
     childCtx, cancel := context.WithCancel(ctx)
     nodes := make(map[int64]Node)
     c := &queryNodeCluster{
-        ctx:         childCtx,
-        cancel:      cancel,
-        client:      kv,
-        session:     session,
-        clusterMeta: clusterMeta,
-        nodes:       nodes,
-        newNodeFn:   newNodeFn,
+        ctx:              childCtx,
+        cancel:           cancel,
+        client:           kv,
+        session:          session,
+        clusterMeta:      clusterMeta,
+        nodes:            nodes,
+        newNodeFn:        newNodeFn,
+        segmentAllocator: defaultSegAllocatePolicy(),
+        channelAllocator: defaultChannelAllocatePolicy(),
     }
     err := c.reloadFromKV()
     if err != nil {
@@ -642,3 +649,11 @@ func (c *queryNodeCluster) getCollectionInfosByID(ctx context.Context, nodeID in
 
     return nil
 }
+
+func (c *queryNodeCluster) allocateSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, wait bool, excludeNodeIDs []int64) error {
+    return c.segmentAllocator(ctx, reqs, c, wait, excludeNodeIDs)
+}
+
+func (c *queryNodeCluster) allocateChannelsToQueryNode(ctx context.Context, reqs []*querypb.WatchDmChannelsRequest, wait bool, excludeNodeIDs []int64) error {
+    return c.channelAllocator(ctx, reqs, c, wait, excludeNodeIDs)
+}
diff --git a/internal/querycoord/mock_3rd_component_test.go b/internal/querycoord/mock_3rd_component_test.go
index 4355ff84dc..55f3da22d8 100644
--- a/internal/querycoord/mock_3rd_component_test.go
+++ b/internal/querycoord/mock_3rd_component_test.go
@@ -38,11 +38,12 @@ import (
 )
 
 const (
-    defaultCollectionID = UniqueID(2021)
-    defaultPartitionID  = UniqueID(2021)
-    defaultSegmentID    = UniqueID(2021)
-    defaultQueryNodeID  = int64(100)
-    defaultChannelNum   = 2
+    defaultCollectionID     = UniqueID(2021)
+    defaultPartitionID      = UniqueID(2021)
+    defaultSegmentID        = UniqueID(2021)
+    defaultQueryNodeID      = int64(100)
+    defaultChannelNum       = 2
+    defaultNumRowPerSegment = 10000
 )
 
 func genCollectionSchema(collectionID UniqueID, isBinary bool) *schemapb.CollectionSchema {
@@ -347,6 +348,7 @@ func (data *dataCoordMock) GetRecoveryInfo(ctx context.Context, req *datapb.GetR
         segmentBinlog := &datapb.SegmentBinlogs{
             SegmentID:    segmentID,
             FieldBinlogs: fieldBinlogs,
+            NumOfRows:    defaultNumRowPerSegment,
         }
         data.Segment2Binlog[segmentID] = segmentBinlog
     }
diff --git a/internal/querycoord/segment_allocator.go b/internal/querycoord/segment_allocator.go
new file mode 100644
index 0000000000..5c6a626510
--- /dev/null
+++ b/internal/querycoord/segment_allocator.go
@@ -0,0 +1,187 @@
+// Copyright (C) 2019-2020 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License.
+
+package querycoord
+
+import (
+    "context"
+    "errors"
+    "sort"
+    "time"
+
+    "go.uber.org/zap"
+
+    "github.com/milvus-io/milvus/internal/log"
+    "github.com/milvus-io/milvus/internal/proto/querypb"
+    "github.com/milvus-io/milvus/internal/util/typeutil"
+)
+
+const MaxMemUsagePerNode = 0.9 // maximum fraction of a query node's memory that the allocator may fill
+
+func defaultSegAllocatePolicy() SegmentAllocatePolicy {
+    return shuffleSegmentsToQueryNodeV2
+}
+
+// SegmentAllocatePolicy is the signature of the helper functions that allocate segments to queryNodes
+type SegmentAllocatePolicy func(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64) error
+
+// shuffleSegmentsToQueryNode shuffles segments across online query nodes.
+// On success, every request in reqs has its DstNodeID set to the chosen node,
+// preferring the node that currently holds the fewest segments.
+func shuffleSegmentsToQueryNode(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64) error {
+    if len(reqs) == 0 {
+        return nil
+    }
+
+    for {
+        availableNodes, err := cluster.onlineNodes()
+        if err != nil {
+            log.Debug(err.Error())
+            if !wait {
+                return err
+            }
+            time.Sleep(1 * time.Second)
+            continue
+        }
+        for _, id := range excludeNodeIDs {
+            delete(availableNodes, id)
+        }
+
+        nodeID2NumSegment := make(map[int64]int)
+        for nodeID := range availableNodes {
+            numSegments, err := cluster.getNumSegments(nodeID)
+            if err != nil {
+                delete(availableNodes, nodeID)
+                continue
+            }
+            nodeID2NumSegment[nodeID] = numSegments
+        }
+
+        if len(availableNodes) > 0 {
+            nodeIDSlice := make([]int64, 0)
+            for nodeID := range availableNodes {
+                nodeIDSlice = append(nodeIDSlice, nodeID)
+            }
+
+            for _, req := range reqs {
+                sort.Slice(nodeIDSlice, func(i, j int) bool {
+                    return nodeID2NumSegment[nodeIDSlice[i]] < nodeID2NumSegment[nodeIDSlice[j]]
+                })
+                req.DstNodeID = nodeIDSlice[0]
+                nodeID2NumSegment[nodeIDSlice[0]]++
+            }
+            return nil
+        }
+
+        if !wait {
+            return errors.New("no queryNode to allocate")
+        }
+        // every node is offline or excluded; back off before rechecking
+        time.Sleep(1 * time.Second)
+    }
+}
+
+// shuffleSegmentsToQueryNodeV2 assigns each request to the node with the lowest
+// memory usage rate, estimating the request's in-memory cost from the collection schema
+func shuffleSegmentsToQueryNodeV2(ctx context.Context, reqs []*querypb.LoadSegmentsRequest, cluster Cluster, wait bool, excludeNodeIDs []int64) error {
+    if len(reqs) == 0 {
+        return nil
+    }
+    // dataSizePerReq holds the estimated memory cost of each request, indexed by request offset
+    dataSizePerReq := make([]int64, 0)
+    for _, req := range reqs {
+        sizePerRecord, err := typeutil.EstimateSizePerRecord(req.Schema)
+        if err != nil {
+            return err
+        }
+        sizeOfReq := int64(0)
+        for _, loadInfo := range req.Infos {
+            sizeOfReq += int64(sizePerRecord) * loadInfo.NumOfRows
+        }
+        dataSizePerReq = append(dataSizePerReq, sizeOfReq)
+    }
+
+    for {
+        // online nodes map and totalMem, usedMem, memUsage of every node
+        totalMem := make(map[int64]uint64)
+        memUsage := make(map[int64]uint64)
+        memUsageRate := make(map[int64]float64)
+        availableNodes, err := cluster.onlineNodes()
+        if err != nil && !wait {
+            return errors.New("no online queryNode to allocate")
+        }
+        for _, id := range excludeNodeIDs {
+            delete(availableNodes, id)
+        }
+        for nodeID := range availableNodes {
+            // collect nodeInfo, used memory and memory usage rate of every query node
+            nodeInfo, err := cluster.getNodeInfoByID(nodeID)
+            if err != nil {
+                log.Debug("shuffleSegmentsToQueryNodeV2: getNodeInfoByID failed", zap.Error(err))
+                delete(availableNodes, nodeID)
+                continue
+            }
+            queryNodeInfo := nodeInfo.(*queryNode)
+            // avoid allocating segments to a node whose memUsageRate is already high
+            if queryNodeInfo.memUsageRate >= MaxMemUsagePerNode {
+                log.Debug("shuffleSegmentsToQueryNodeV2: queryNode memUsageRate larger than MaxMemUsagePerNode", zap.Int64("nodeID", nodeID), zap.Float64("current rate", queryNodeInfo.memUsageRate))
+                delete(availableNodes, nodeID)
+                continue
+            }
+
+            // update totalMem, memUsage, memUsageRate
+            totalMem[nodeID], memUsage[nodeID], memUsageRate[nodeID] = queryNodeInfo.totalMem, queryNodeInfo.memUsage, queryNodeInfo.memUsageRate
+        }
+        if len(availableNodes) > 0 {
+            nodeIDSlice := make([]int64, 0, len(availableNodes))
+            for nodeID := range availableNodes {
+                nodeIDSlice = append(nodeIDSlice, nodeID)
+            }
+            allocateSegmentsDone := true
+            for offset, sizeOfReq := range dataSizePerReq {
+                // sort nodes by memUsageRate, low to high
+                sort.Slice(nodeIDSlice, func(i, j int) bool {
+                    return memUsageRate[nodeIDSlice[i]] < memUsageRate[nodeIDSlice[j]]
+                })
+                findNodeToAllocate := false
+                // assign the request to the query node with the lowest memUsageRate
+                for _, nodeID := range nodeIDSlice {
+                    memUsageAfterLoad := memUsage[nodeID] + uint64(sizeOfReq)
+                    memUsageRateAfterLoad := float64(memUsageAfterLoad) / float64(totalMem[nodeID])
+                    if memUsageRateAfterLoad > MaxMemUsagePerNode {
+                        continue
+                    }
+                    reqs[offset].DstNodeID = nodeID
+                    memUsage[nodeID] = memUsageAfterLoad
+                    memUsageRate[nodeID] = memUsageRateAfterLoad
+                    findNodeToAllocate = true
+                    break
+                }
+                // the load segment request can't be allocated to any query node
+                if !findNodeToAllocate {
+                    allocateSegmentsDone = false
+                    break
+                }
+            }
+
+            if allocateSegmentsDone {
+                return nil
+            }
+        }
+
+        if wait {
+            time.Sleep(1 * time.Second)
+            continue
+        } else {
+            return errors.New("no queryNode to allocate")
+        }
+    }
+}
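The admission test in shuffleSegmentsToQueryNodeV2 reduces to one line of arithmetic: a request fits on a node only if (memUsage + sizeOfReq) / totalMem stays at or below MaxMemUsagePerNode. A standalone sketch of that check (canLoad is a hypothetical helper, not part of this patch):

func canLoad(memUsage, totalMem, sizeOfReq uint64) bool {
    // mirrors the check in shuffleSegmentsToQueryNodeV2: the node is rejected if
    // loading the request would push its memory usage rate past the cap
    memUsageAfterLoad := memUsage + sizeOfReq
    return float64(memUsageAfterLoad)/float64(totalMem) <= MaxMemUsagePerNode
}

For example, a node with 16 GiB total and 13 GiB used is skipped for a 2 GiB request, since (13+2)/16 = 0.9375 exceeds 0.9, and the allocator falls through to the next node in the sorted slice.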
diff --git a/internal/querycoord/segment_allocator_test.go b/internal/querycoord/segment_allocator_test.go
new file mode 100644
index 0000000000..65377a42d4
--- /dev/null
+++ b/internal/querycoord/segment_allocator_test.go
@@ -0,0 +1,110 @@
+// Copyright (C) 2019-2020 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License.
+
+package querycoord
+
+import (
+    "context"
+    "testing"
+
+    "github.com/stretchr/testify/assert"
+
+    etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
+    "github.com/milvus-io/milvus/internal/proto/querypb"
+    "github.com/milvus-io/milvus/internal/util/sessionutil"
+    "github.com/milvus-io/milvus/internal/util/typeutil"
+)
+
+func TestShuffleSegmentsToQueryNode(t *testing.T) {
+    refreshParams()
+    baseCtx, cancel := context.WithCancel(context.Background())
+    kv, err := etcdkv.NewEtcdKV(Params.EtcdEndpoints, Params.MetaRootPath)
+    assert.Nil(t, err)
+    clusterSession := sessionutil.NewSession(context.Background(), Params.MetaRootPath, Params.EtcdEndpoints)
+    clusterSession.Init(typeutil.QueryCoordRole, Params.Address, true)
+    meta, err := newMeta(baseCtx, kv, nil, nil)
+    assert.Nil(t, err)
+    cluster := &queryNodeCluster{
+        ctx:         baseCtx,
+        cancel:      cancel,
+        client:      kv,
+        clusterMeta: meta,
+        nodes:       make(map[int64]Node),
+        newNodeFn:   newQueryNodeTest,
+        session:     clusterSession,
+    }
+
+    schema := genCollectionSchema(defaultCollectionID, false)
+    firstReq := &querypb.LoadSegmentsRequest{
+        CollectionID: defaultCollectionID,
+        Schema:       schema,
+        Infos: []*querypb.SegmentLoadInfo{
+            {
+                SegmentID:    defaultSegmentID,
+                PartitionID:  defaultPartitionID,
+                CollectionID: defaultCollectionID,
+                NumOfRows:    defaultNumRowPerSegment,
+            },
+        },
+    }
+    secondReq := &querypb.LoadSegmentsRequest{
+        CollectionID: defaultCollectionID,
+        Schema:       schema,
+        Infos: []*querypb.SegmentLoadInfo{
+            {
+                SegmentID:    defaultSegmentID + 1,
+                PartitionID:  defaultPartitionID,
+                CollectionID: defaultCollectionID,
+                NumOfRows:    defaultNumRowPerSegment,
+            },
+        },
+    }
+    reqs := []*querypb.LoadSegmentsRequest{firstReq, secondReq}
+
+    t.Run("Test shuffleSegmentsWithoutQueryNode", func(t *testing.T) {
+        err = shuffleSegmentsToQueryNode(baseCtx, reqs, cluster, false, nil)
+        assert.NotNil(t, err)
+    })
+
+    node1, err := startQueryNodeServer(baseCtx)
+    assert.Nil(t, err)
+    node1Session := node1.session
+    node1ID := node1.queryNodeID
+    cluster.registerNode(baseCtx, node1Session, node1ID, disConnect)
+    waitQueryNodeOnline(cluster, node1ID)
+
+    t.Run("Test shuffleSegmentsToQueryNode", func(t *testing.T) {
+        err = shuffleSegmentsToQueryNode(baseCtx, reqs, cluster, false, nil)
+        assert.Nil(t, err)
+
+        assert.Equal(t, node1ID, firstReq.DstNodeID)
+        assert.Equal(t, node1ID, secondReq.DstNodeID)
+    })
+
+    node2, err := startQueryNodeServer(baseCtx)
+    assert.Nil(t, err)
+    node2Session := node2.session
+    node2ID := node2.queryNodeID
+    cluster.registerNode(baseCtx, node2Session, node2ID, disConnect)
+    waitQueryNodeOnline(cluster, node2ID)
+    cluster.stopNode(node1ID)
+
+    t.Run("Test shuffleSegmentsToQueryNodeV2", func(t *testing.T) {
+        err = shuffleSegmentsToQueryNodeV2(baseCtx, reqs, cluster, false, nil)
+        assert.Nil(t, err)
+
+        assert.Equal(t, node2ID, firstReq.DstNodeID)
+        assert.Equal(t, node2ID, secondReq.DstNodeID)
+    })
+
+    err = removeAllSession()
+    assert.Nil(t, err)
+}
diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go
index 1f934401bb..8cfa158f28 100644
--- a/internal/querycoord/task.go
+++ b/internal/querycoord/task.go
@@ -16,7 +16,6 @@ import (
     "errors"
     "fmt"
     "sync"
-    "time"
 
     "github.com/golang/protobuf/proto"
     "go.uber.org/zap"
@@ -384,6 +383,7 @@ func (lct *loadCollectionTask) execute(ctx context.Context) error {
             Infos:         []*querypb.SegmentLoadInfo{segmentLoadInfo},
             Schema:        lct.Schema,
             LoadCondition: querypb.TriggerCondition_grpcRequest,
+            CollectionID:  collectionID,
         }
 
         segmentsToLoad = append(segmentsToLoad, segmentID)
@@ -453,12 +453,16 @@ func (lct *loadCollectionTask) execute(ctx context.Context) error {
         }
     }
 
-    err = assignInternalTask(ctx, collectionID, lct, lct.meta, lct.cluster, loadSegmentReqs, watchDmChannelReqs, watchDeltaChannelReqs, false, nil)
+    internalTasks, err := assignInternalTask(ctx, collectionID, lct, lct.meta, lct.cluster, loadSegmentReqs, watchDmChannelReqs, watchDeltaChannelReqs, false, nil)
     if err != nil {
         log.Warn("loadCollectionTask: assign child task failed", zap.Int64("collectionID", collectionID))
         lct.setResultInfo(err)
         return err
     }
+    for _, internalTask := range internalTasks {
+        lct.addChildTask(internalTask)
+        log.Debug("loadCollectionTask: add a childTask", zap.Int32("task type", int32(internalTask.msgType())), zap.Int64("collectionID", collectionID), zap.Any("task", internalTask))
+    }
     log.Debug("loadCollectionTask: assign child task done", zap.Int64("collectionID", collectionID))
 
     log.Debug("LoadCollection execute done",
@@ -735,6 +739,7 @@ func (lpt *loadPartitionTask) execute(ctx context.Context) error {
                 Infos:         []*querypb.SegmentLoadInfo{segmentLoadInfo},
                 Schema:        lpt.Schema,
                 LoadCondition: querypb.TriggerCondition_grpcRequest,
+                CollectionID:  collectionID,
             }
             segmentsToLoad = append(segmentsToLoad, segmentID)
             loadSegmentReqs = append(loadSegmentReqs, loadSegmentReq)
@@ -778,12 +783,16 @@ func (lpt *loadPartitionTask) execute(ctx context.Context) error {
         }
     }
 
-    err := assignInternalTask(ctx, collectionID, lpt, lpt.meta, lpt.cluster, loadSegmentReqs, watchDmReqs, watchDeltaReqs, false, nil)
+    internalTasks, err := assignInternalTask(ctx, collectionID, lpt, lpt.meta, lpt.cluster, loadSegmentReqs, watchDmReqs, watchDeltaReqs, false, nil)
     if err != nil {
         log.Warn("loadPartitionTask: assign child task failed", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
         lpt.setResultInfo(err)
         return err
     }
+    for _, internalTask := range internalTasks {
+        lpt.addChildTask(internalTask)
+        log.Debug("loadPartitionTask: add a childTask", zap.Int32("task type", int32(internalTask.msgType())), zap.Int64("collectionID", collectionID), zap.Any("task", internalTask))
+    }
     log.Debug("loadPartitionTask: assign child task done", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
 
     log.Debug("loadPartitionTask Execute done",
@@ -1053,78 +1062,33 @@ func (lst *loadSegmentTask) postExecute(context.Context) error {
 }
 
 func (lst *loadSegmentTask) reschedule(ctx context.Context) ([]task, error) {
-    segmentIDs := make([]UniqueID, 0)
-    collectionID := lst.Infos[0].CollectionID
-    reScheduledTask := make([]task, 0)
+    loadSegmentReqs := make([]*querypb.LoadSegmentsRequest, 0)
+    collectionID := lst.CollectionID
     for _, info := range lst.Infos {
-        segmentIDs = append(segmentIDs, info.SegmentID)
+        msgBase := proto.Clone(lst.Base).(*commonpb.MsgBase)
+        msgBase.MsgType = commonpb.MsgType_LoadSegments
+        req := &querypb.LoadSegmentsRequest{
+            Base:          msgBase,
+            Infos:         []*querypb.SegmentLoadInfo{info},
+            Schema:        lst.Schema,
+            LoadCondition: lst.triggerCondition,
+            SourceNodeID:  lst.SourceNodeID,
+            CollectionID:  lst.CollectionID,
+        }
+        loadSegmentReqs = append(loadSegmentReqs, req)
+    }
+    if lst.excludeNodeIDs == nil {
+        lst.excludeNodeIDs = []int64{}
     }
     lst.excludeNodeIDs = append(lst.excludeNodeIDs, lst.DstNodeID)
-    segment2Nodes, err := shuffleSegmentsToQueryNode(segmentIDs, lst.cluster, false, lst.excludeNodeIDs)
+    // TODO: decide whether to wait according to msgType
+    reScheduledTasks, err := assignInternalTask(ctx, collectionID, lst.getParentTask(), lst.meta, lst.cluster, loadSegmentReqs, nil, nil, false, lst.excludeNodeIDs)
     if err != nil {
         log.Error("loadSegment reschedule failed", zap.Int64s("excludeNodes", lst.excludeNodeIDs), zap.Error(err))
         return nil, err
     }
-    node2segmentInfos := make(map[int64][]*querypb.SegmentLoadInfo)
-    for index, info := range lst.Infos {
-        nodeID := segment2Nodes[index]
-        if _, ok := node2segmentInfos[nodeID]; !ok {
-            node2segmentInfos[nodeID] = make([]*querypb.SegmentLoadInfo, 0)
-        }
-        node2segmentInfos[nodeID] = append(node2segmentInfos[nodeID], info)
-    }
-    for nodeID, infos := range node2segmentInfos {
-        loadSegmentBaseTask := newBaseTask(ctx, lst.getTriggerCondition())
-        loadSegmentBaseTask.setParentTask(lst.getParentTask())
-        loadSegmentTask := &loadSegmentTask{
-            baseTask: loadSegmentBaseTask,
-            LoadSegmentsRequest: &querypb.LoadSegmentsRequest{
-                Base:          lst.Base,
-                DstNodeID:     nodeID,
-                Infos:         infos,
-                Schema:        lst.Schema,
-                LoadCondition: lst.LoadCondition,
-            },
-            meta:           lst.meta,
-            cluster:        lst.cluster,
-            excludeNodeIDs: lst.excludeNodeIDs,
-        }
-        reScheduledTask = append(reScheduledTask, loadSegmentTask)
-        log.Debug("loadSegmentTask: add a loadSegmentTask to RescheduleTasks", zap.Any("task", loadSegmentTask))
-
-        hasWatchQueryChannel := lst.cluster.hasWatchedQueryChannel(lst.ctx, nodeID, collectionID)
-        if !hasWatchQueryChannel {
-            queryChannelInfo, err := lst.meta.getQueryChannelInfoByID(collectionID)
-            if err != nil {
-                return nil, err
-            }
-
-            msgBase := proto.Clone(lst.Base).(*commonpb.MsgBase)
-            msgBase.MsgType = commonpb.MsgType_WatchQueryChannels
-            addQueryChannelRequest := &querypb.AddQueryChannelRequest{
-                Base:                 msgBase,
-                NodeID:               nodeID,
-                CollectionID:         collectionID,
-                RequestChannelID:     queryChannelInfo.QueryChannelID,
-                ResultChannelID:      queryChannelInfo.QueryResultChannelID,
-                GlobalSealedSegments: queryChannelInfo.GlobalSealedSegments,
-                SeekPosition:         queryChannelInfo.SeekPosition,
-            }
-            watchQueryChannelBaseTask := newBaseTask(ctx, lst.getTriggerCondition())
-            watchQueryChannelBaseTask.setParentTask(lst.getParentTask())
-            watchQueryChannelTask := &watchQueryChannelTask{
-                baseTask:               watchQueryChannelBaseTask,
-                AddQueryChannelRequest: addQueryChannelRequest,
-                cluster:                lst.cluster,
-            }
-            reScheduledTask = append(reScheduledTask, watchQueryChannelTask)
-            log.Debug("loadSegmentTask: add a watchQueryChannelTask to RescheduleTasks", zap.Any("task", watchQueryChannelTask))
-        }
-
-    }
-
-    return reScheduledTask, nil
+    return reScheduledTasks, nil
 }
 
 type releaseSegmentTask struct {
@@ -1273,79 +1237,33 @@ func (wdt *watchDmChannelTask) postExecute(context.Context) error {
 
 func (wdt *watchDmChannelTask) reschedule(ctx context.Context) ([]task, error) {
     collectionID := wdt.CollectionID
-    channelIDs := make([]string, 0)
-    reScheduledTask := make([]task, 0)
+    watchDmChannelReqs := make([]*querypb.WatchDmChannelsRequest, 0)
     for _, info := range wdt.Infos {
-        channelIDs = append(channelIDs, info.ChannelName)
+        msgBase := proto.Clone(wdt.Base).(*commonpb.MsgBase)
+        msgBase.MsgType = commonpb.MsgType_WatchDmChannels
+        req := &querypb.WatchDmChannelsRequest{
+            Base:         msgBase,
+            CollectionID: collectionID,
+            PartitionID:  wdt.PartitionID,
+            Infos:        []*datapb.VchannelInfo{info},
+            Schema:       wdt.Schema,
+            ExcludeInfos: wdt.ExcludeInfos,
+        }
+        watchDmChannelReqs = append(watchDmChannelReqs, req)
     }
+    if wdt.excludeNodeIDs == nil {
+        wdt.excludeNodeIDs = []int64{}
+    }
     wdt.excludeNodeIDs = append(wdt.excludeNodeIDs, wdt.NodeID)
-    channel2Nodes, err := shuffleChannelsToQueryNode(channelIDs, wdt.cluster, false, wdt.excludeNodeIDs)
+    // TODO: decide whether to wait according to msgType
+    reScheduledTasks, err := assignInternalTask(ctx, collectionID, wdt.parentTask, wdt.meta, wdt.cluster, nil, watchDmChannelReqs, nil, false, wdt.excludeNodeIDs)
     if err != nil {
         log.Error("watchDmChannel reschedule failed", zap.Int64s("excludeNodes", wdt.excludeNodeIDs), zap.Error(err))
         return nil, err
     }
-    node2channelInfos := make(map[int64][]*datapb.VchannelInfo)
-    for index, info := range wdt.Infos {
-        nodeID := channel2Nodes[index]
-        if _, ok := node2channelInfos[nodeID]; !ok {
-            node2channelInfos[nodeID] = make([]*datapb.VchannelInfo, 0)
-        }
-        node2channelInfos[nodeID] = append(node2channelInfos[nodeID], info)
-    }
-    for nodeID, infos := range node2channelInfos {
-        watchDmChannelBaseTask := newBaseTask(ctx, wdt.getTriggerCondition())
-        watchDmChannelBaseTask.setParentTask(wdt.getParentTask())
-        watchDmChannelTask := &watchDmChannelTask{
-            baseTask: watchDmChannelBaseTask,
-            WatchDmChannelsRequest: &querypb.WatchDmChannelsRequest{
-                Base:         wdt.Base,
-                NodeID:       nodeID,
-                CollectionID: wdt.CollectionID,
-                PartitionID:  wdt.PartitionID,
-                Infos:        infos,
-                Schema:       wdt.Schema,
-                ExcludeInfos: wdt.ExcludeInfos,
-            },
-            meta:           wdt.meta,
-            cluster:        wdt.cluster,
-            excludeNodeIDs: wdt.excludeNodeIDs,
-        }
-        reScheduledTask = append(reScheduledTask, watchDmChannelTask)
-        log.Debug("watchDmChannelTask: add a watchDmChannelTask to RescheduleTasks", zap.Any("task", watchDmChannelTask))
-
-        hasWatchQueryChannel := wdt.cluster.hasWatchedQueryChannel(wdt.ctx, nodeID, collectionID)
-        if !hasWatchQueryChannel {
-            queryChannelInfo, err := wdt.meta.getQueryChannelInfoByID(collectionID)
-            if err != nil {
-                return nil, err
-            }
-
-            msgBase := proto.Clone(wdt.Base).(*commonpb.MsgBase)
-            msgBase.MsgType = commonpb.MsgType_WatchQueryChannels
-            addQueryChannelRequest := &querypb.AddQueryChannelRequest{
-                Base:                 msgBase,
-                NodeID:               nodeID,
-                CollectionID:         collectionID,
-                RequestChannelID:     queryChannelInfo.QueryChannelID,
-                ResultChannelID:      queryChannelInfo.QueryResultChannelID,
-                GlobalSealedSegments: queryChannelInfo.GlobalSealedSegments,
-                SeekPosition:         queryChannelInfo.SeekPosition,
-            }
-            watchQueryChannelBaseTask := newBaseTask(ctx, wdt.getTriggerCondition())
-            watchQueryChannelBaseTask.setParentTask(wdt.getParentTask())
-            watchQueryChannelTask := &watchQueryChannelTask{
-                baseTask:               watchQueryChannelBaseTask,
-                AddQueryChannelRequest: addQueryChannelRequest,
-                cluster:                wdt.cluster,
-            }
-            reScheduledTask = append(reScheduledTask, watchQueryChannelTask)
-            log.Debug("watchDmChannelTask: add a watchQueryChannelTask to RescheduleTasks", zap.Any("task", watchQueryChannelTask))
-        }
-    }
-
-    return reScheduledTask, nil
+    return reScheduledTasks, nil
 }
 
 type watchDeltaChannelTask struct {
@@ -1639,12 +1557,16 @@ func (ht *handoffTask) execute(ctx context.Context) error {
             ht.setResultInfo(err)
             return err
         }
-        err = assignInternalTask(ctx, collectionID, ht, ht.meta, ht.cluster, []*querypb.LoadSegmentsRequest{loadSegmentReq}, nil, watchDeltaChannelReqs, true, nil)
+        internalTasks, err := assignInternalTask(ctx, collectionID, ht, ht.meta, ht.cluster, []*querypb.LoadSegmentsRequest{loadSegmentReq}, nil, watchDeltaChannelReqs, true, nil)
         if err != nil {
            log.Error("handoffTask: assign child task failed", zap.Any("segmentInfo", segmentInfo))
            ht.setResultInfo(err)
            return err
         }
+        for _, internalTask := range internalTasks {
+            ht.addChildTask(internalTask)
+            log.Debug("handoffTask: add a childTask", zap.Int32("task type", int32(internalTask.msgType())), zap.Int64("segmentID", segmentID), zap.Any("task", internalTask))
+        }
     } else {
         err = fmt.Errorf("sealed segment has been exist on query node, segmentID is %d", segmentID)
         log.Error("handoffTask: sealed segment has been exist on query node", zap.Int64("segmentID", segmentID))
@@ -1851,12 +1773,17 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
                 }
             }
         }
-        err = assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, watchDmChannelReqs, watchDeltaChannelReqs, true, lbt.SourceNodeIDs)
+
+        internalTasks, err := assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, watchDmChannelReqs, watchDeltaChannelReqs, true, lbt.SourceNodeIDs)
         if err != nil {
            log.Warn("loadBalanceTask: assign child task failed", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
            lbt.setResultInfo(err)
            return err
         }
+        for _, internalTask := range internalTasks {
+            lbt.addChildTask(internalTask)
+            log.Debug("loadBalanceTask: add a childTask", zap.Int32("task type", int32(internalTask.msgType())), zap.Any("task", internalTask))
+        }
         log.Debug("loadBalanceTask: assign child task done", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
     }
 }
@@ -1998,12 +1925,16 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error {
         }
 
         // TODO:: assignInternalTask with multi collection
-        err = assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, nil, watchDeltaChannelReqs, false, lbt.SourceNodeIDs)
+        internalTasks, err := assignInternalTask(ctx, collectionID, lbt, lbt.meta, lbt.cluster, loadSegmentReqs, nil, watchDeltaChannelReqs, false, lbt.SourceNodeIDs)
         if err != nil {
-            log.Warn("loadBalanceTask: assign child task failed", zap.Int64("collectionID", collectionID))
+            log.Warn("loadBalanceTask: assign child task failed", zap.Int64("collectionID", collectionID), zap.Int64s("partitionIDs", partitionIDs))
            lbt.setResultInfo(err)
            return err
         }
+        for _, internalTask := range internalTasks {
+            lbt.addChildTask(internalTask)
+            log.Debug("loadBalanceTask: add a childTask", zap.Int32("task type", int32(internalTask.msgType())), zap.Any("task", internalTask))
+        }
     }
     log.Debug("loadBalanceTask: assign child task done", zap.Any("balance request", lbt.LoadBalanceRequest))
 }
@@ -2038,143 +1969,6 @@ func (lbt *loadBalanceTask) postExecute(context.Context) error {
     return nil
 }
 
-func shuffleChannelsToQueryNode(dmChannels []string, cluster Cluster, wait bool, excludeNodeIDs []int64) ([]int64, error) {
-    maxNumChannels := 0
-    nodes := make(map[int64]Node)
-    var err error
-    for {
-        nodes, err = cluster.onlineNodes()
-        if err != nil {
-            log.Debug(err.Error())
-            if !wait {
-                return nil, err
-            }
-            time.Sleep(1 * time.Second)
-            continue
-        }
-        for _, id := range excludeNodeIDs {
-            delete(nodes, id)
-        }
-        if len(nodes) > 0 {
-            break
-        }
-        if !wait {
-            return nil, errors.New("no queryNode to allocate")
-        }
-    }
-
-    for nodeID := range nodes {
-        numChannels, _ := cluster.getNumDmChannels(nodeID)
-        if numChannels > maxNumChannels {
-            maxNumChannels = numChannels
-        }
-    }
-    res := make([]int64, 0)
-    if len(dmChannels) == 0 {
-        return res, nil
-    }
-
-    offset := 0
-    loopAll := false
-    for {
-        lastOffset := offset
-        if !loopAll {
-            for nodeID := range nodes {
-                numSegments, _ := cluster.getNumSegments(nodeID)
-                if numSegments >= maxNumChannels {
-                    continue
-                }
-                res = append(res, nodeID)
-                offset++
-                if offset == len(dmChannels) {
-                    return res, nil
-                }
-            }
-        } else {
-            for nodeID := range nodes {
-                res = append(res, nodeID)
-                offset++
-                if offset == len(dmChannels) {
-                    return res, nil
-                }
-            }
-        }
-        if lastOffset == offset {
-            loopAll = true
-        }
-    }
-}
-
-// shuffleSegmentsToQueryNode shuffle segments to online nodes
-// returned are noded id for each segment, which satisfies:
-// len(returnedNodeIds) == len(segmentIDs) && segmentIDs[i] is assigned to returnedNodeIds[i]
-func shuffleSegmentsToQueryNode(segmentIDs []UniqueID, cluster Cluster, wait bool, excludeNodeIDs []int64) ([]int64, error) {
-    maxNumSegments := 0
-    nodes := make(map[int64]Node)
-    var err error
-    for {
-        nodes, err = cluster.onlineNodes()
-        if err != nil {
-            log.Debug(err.Error())
-            if !wait {
-                return nil, err
-            }
-            time.Sleep(1 * time.Second)
-            continue
-        }
-        for _, id := range excludeNodeIDs {
-            delete(nodes, id)
-        }
-        if len(nodes) > 0 {
-            break
-        }
-        if !wait {
-            return nil, errors.New("no queryNode to allocate")
-        }
-    }
-    for nodeID := range nodes {
-        numSegments, _ := cluster.getNumSegments(nodeID)
-        if numSegments > maxNumSegments {
-            maxNumSegments = numSegments
-        }
-    }
-    res := make([]int64, 0)
-
-    if len(segmentIDs) == 0 {
-        return res, nil
-    }
-
-    offset := 0
-    loopAll := false
-    for {
-        lastOffset := offset
-        if !loopAll {
-            for nodeID := range nodes {
-                numSegments, _ := cluster.getNumSegments(nodeID)
-                if numSegments >= maxNumSegments {
-                    continue
-                }
-                res = append(res, nodeID)
-                offset++
-                if offset == len(segmentIDs) {
-                    return res, nil
-                }
-            }
-        } else {
-            for nodeID := range nodes {
-                res = append(res, nodeID)
-                offset++
-                if offset == len(segmentIDs) {
-                    return res, nil
-                }
-            }
-        }
-        if lastOffset == offset {
-            loopAll = true
-        }
-    }
-}
-
 func mergeVChannelInfo(info1 *datapb.VchannelInfo, info2 *datapb.VchannelInfo) *datapb.VchannelInfo {
     collectionID := info1.CollectionID
     channelName := info1.ChannelName
@@ -2208,53 +2002,45 @@ func mergeVChannelInfo(info1 *datapb.VchannelInfo, info2 *datapb.VchannelInfo) *
     }
 }
 
 func assignInternalTask(ctx context.Context,
-    collectionID UniqueID,
-    parentTask task,
-    meta Meta,
-    cluster Cluster,
+    collectionID UniqueID, parentTask task, meta Meta, cluster Cluster,
     loadSegmentRequests []*querypb.LoadSegmentsRequest,
     watchDmChannelRequests []*querypb.WatchDmChannelsRequest,
     watchDeltaChannelRequests []*querypb.WatchDeltaChannelsRequest,
-    wait bool, excludeNodeIDs []int64) error {
+    wait bool, excludeNodeIDs []int64) ([]task, error) {
     sp, _ := trace.StartSpanFromContext(ctx)
     defer sp.Finish()
-    segmentsToLoad := make([]UniqueID, 0)
-    for _, req := range loadSegmentRequests {
-        segmentsToLoad = append(segmentsToLoad, req.Infos[0].SegmentID)
-    }
-    channelsToWatch := make([]string, 0)
-    for _, req := range watchDmChannelRequests {
-        channelsToWatch = append(channelsToWatch, req.Infos[0].ChannelName)
-    }
-    segment2Nodes, err := shuffleSegmentsToQueryNode(segmentsToLoad, cluster, wait, excludeNodeIDs)
+    internalTasks := make([]task, 0)
+    err := cluster.allocateSegmentsToQueryNode(ctx, loadSegmentRequests, wait, excludeNodeIDs)
     if err != nil {
-        log.Error("assignInternalTask: segment to node failed", zap.Any("segments map", segment2Nodes), zap.Int64("collectionID", collectionID))
-        return err
+        log.Error("assignInternalTask: assign segment to node failed", zap.Any("load segments requests", loadSegmentRequests))
+        return nil, err
     }
-    log.Debug("assignInternalTask: segment to node", zap.Any("segments map", segment2Nodes), zap.Int64("collectionID", collectionID))
-    watchRequest2Nodes, err := shuffleChannelsToQueryNode(channelsToWatch, cluster, wait, excludeNodeIDs)
+    log.Debug("assignInternalTask: assign segment to node success", zap.Any("load segments requests", loadSegmentRequests))
+
+    err = cluster.allocateChannelsToQueryNode(ctx, watchDmChannelRequests, wait, excludeNodeIDs)
     if err != nil {
-        log.Error("assignInternalTask: watch request to node failed", zap.Any("request map", watchRequest2Nodes), zap.Int64("collectionID", collectionID))
-        return err
+        log.Error("assignInternalTask: assign dmChannel to node failed", zap.Any("watch dmChannel requests", watchDmChannelRequests))
+        return nil, err
     }
-    log.Debug("assignInternalTask: watch request to node", zap.Any("request map", watchRequest2Nodes), zap.Int64("collectionID", collectionID))
+    log.Debug("assignInternalTask: assign dmChannel to node success", zap.Any("watch dmChannel requests", watchDmChannelRequests))
 
     watchQueryChannelInfo := make(map[int64]bool)
     node2Segments := make(map[int64][]*querypb.LoadSegmentsRequest)
     sizeCounts := make(map[int64]int)
-    for index, nodeID := range segment2Nodes {
-        sizeOfReq := getSizeOfLoadSegmentReq(loadSegmentRequests[index])
+    for _, req := range loadSegmentRequests {
+        nodeID := req.DstNodeID
+        sizeOfReq := getSizeOfLoadSegmentReq(req)
         if _, ok := node2Segments[nodeID]; !ok {
             node2Segments[nodeID] = make([]*querypb.LoadSegmentsRequest, 0)
-            node2Segments[nodeID] = append(node2Segments[nodeID], loadSegmentRequests[index])
+            node2Segments[nodeID] = append(node2Segments[nodeID], req)
             sizeCounts[nodeID] = sizeOfReq
         } else {
             if sizeCounts[nodeID]+sizeOfReq > MaxSendSizeToEtcd {
-                node2Segments[nodeID] = append(node2Segments[nodeID], loadSegmentRequests[index])
+                node2Segments[nodeID] = append(node2Segments[nodeID], req)
                 sizeCounts[nodeID] = sizeOfReq
             } else {
                 lastReq := node2Segments[nodeID][len(node2Segments[nodeID])-1]
-                lastReq.Infos = append(lastReq.Infos, loadSegmentRequests[index].Infos...)
+                lastReq.Infos = append(lastReq.Infos, req.Infos...)
                 sizeCounts[nodeID] += sizeOfReq
             }
         }
@@ -2265,18 +2051,10 @@ func assignInternalTask(ctx context.Context,
         }
         watchQueryChannelInfo[nodeID] = false
     }
-    for _, nodeID := range watchRequest2Nodes {
-        if cluster.hasWatchedQueryChannel(parentTask.traceCtx(), nodeID, collectionID) {
-            watchQueryChannelInfo[nodeID] = true
-            continue
-        }
-        watchQueryChannelInfo[nodeID] = false
-    }
 
     for nodeID, loadSegmentsReqs := range node2Segments {
         for _, req := range loadSegmentsReqs {
             ctx = opentracing.ContextWithSpan(context.Background(), sp)
-            req.DstNodeID = nodeID
             baseTask := newBaseTask(ctx, parentTask.getTriggerCondition())
             baseTask.setParentTask(parentTask)
             loadSegmentTask := &loadSegmentTask{
@@ -2284,10 +2062,9 @@ func assignInternalTask(ctx context.Context,
                 LoadSegmentsRequest: req,
                 meta:                meta,
                 cluster:             cluster,
-                excludeNodeIDs:      []int64{},
+                excludeNodeIDs:      excludeNodeIDs,
             }
-            parentTask.addChildTask(loadSegmentTask)
-            log.Debug("assignInternalTask: add a loadSegmentTask childTask", zap.Any("task", loadSegmentTask))
+            internalTasks = append(internalTasks, loadSegmentTask)
         }
 
         for _, req := range watchDeltaChannelRequests {
@@ -2303,27 +2080,29 @@ func assignInternalTask(ctx context.Context,
                 cluster:   cluster,
                 excludeNodeIDs: []int64{},
             }
-            parentTask.addChildTask(watchDeltaTask)
-            log.Debug("assignInternalTask: add a watchDeltaChannelTask childTask", zap.Any("task", watchDeltaTask))
+            internalTasks = append(internalTasks, watchDeltaTask)
         }
-    }
 
-    for index, nodeID := range watchRequest2Nodes {
+    for _, req := range watchDmChannelRequests {
+        nodeID := req.NodeID
         ctx = opentracing.ContextWithSpan(context.Background(), sp)
-        watchDmChannelReq := watchDmChannelRequests[index]
-        watchDmChannelReq.NodeID = nodeID
         baseTask := newBaseTask(ctx, parentTask.getTriggerCondition())
         baseTask.setParentTask(parentTask)
         watchDmChannelTask := &watchDmChannelTask{
             baseTask:               baseTask,
-            WatchDmChannelsRequest: watchDmChannelReq,
+            WatchDmChannelsRequest: req,
             meta:                   meta,
             cluster:                cluster,
-            excludeNodeIDs:         []int64{},
+            excludeNodeIDs:         excludeNodeIDs,
         }
-        parentTask.addChildTask(watchDmChannelTask)
-        log.Debug("assignInternalTask: add a watchDmChannelTask childTask", zap.Any("task", watchDmChannelTask))
+        internalTasks = append(internalTasks, watchDmChannelTask)
+
+        if cluster.hasWatchedQueryChannel(parentTask.traceCtx(), nodeID, collectionID) {
+            watchQueryChannelInfo[nodeID] = true
+            continue
+        }
+        watchQueryChannelInfo[nodeID] = false
     }
 
     for nodeID, watched := range watchQueryChannelInfo {
@@ -2331,7 +2110,7 @@ func assignInternalTask(ctx context.Context,
             ctx = opentracing.ContextWithSpan(context.Background(), sp)
             queryChannelInfo, err := meta.getQueryChannelInfoByID(collectionID)
             if err != nil {
-                return err
+                return nil, err
             }
 
             msgBase := proto.Clone(parentTask.msgBase()).(*commonpb.MsgBase)
@@ -2353,11 +2132,10 @@ func assignInternalTask(ctx context.Context,
                 AddQueryChannelRequest: addQueryChannelRequest,
                 cluster:                cluster,
             }
-            parentTask.addChildTask(watchQueryChannelTask)
-            log.Debug("assignInternalTask: add a watchQueryChannelTask childTask", zap.Any("task", watchQueryChannelTask))
+            internalTasks = append(internalTasks, watchQueryChannelTask)
         }
     }
-    return nil
+    return internalTasks, nil
 }
 
 func getSizeOfLoadSegmentReq(req *querypb.LoadSegmentsRequest) int {
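Note the contract change running through all of these task.go hunks: assignInternalTask no longer calls parentTask.addChildTask itself. It returns the constructed tasks, and each trigger task decides when to register them. A condensed sketch of the resulting caller pattern (hedged: triggerTask stands for any of lct/lpt/ht/lbt, and logging is elided):

internalTasks, err := assignInternalTask(ctx, collectionID, triggerTask, meta, cluster,
    loadSegmentReqs, watchDmChannelReqs, watchDeltaChannelReqs, false, nil)
if err != nil {
    // each caller records the failure on itself before returning
    triggerTask.setResultInfo(err)
    return err
}
for _, internalTask := range internalTasks {
    // children are attached only after allocation fully succeeded
    triggerTask.addChildTask(internalTask)
}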
diff --git a/internal/querycoord/task_test.go b/internal/querycoord/task_test.go
index 0c90ad42ec..a0cc19743b 100644
--- a/internal/querycoord/task_test.go
+++ b/internal/querycoord/task_test.go
@@ -694,10 +694,10 @@ func Test_AssignInternalTask(t *testing.T) {
         loadSegmentRequests = append(loadSegmentRequests, req)
     }
 
-    err = assignInternalTask(queryCoord.loopCtx, defaultCollectionID, loadCollectionTask, queryCoord.meta, queryCoord.cluster, loadSegmentRequests, nil, nil, false, nil)
+    internalTasks, err := assignInternalTask(queryCoord.loopCtx, defaultCollectionID, loadCollectionTask, queryCoord.meta, queryCoord.cluster, loadSegmentRequests, nil, nil, false, nil)
     assert.Nil(t, err)
 
-    assert.NotEqual(t, 1, len(loadCollectionTask.getChildTask()))
+    assert.NotEqual(t, 1, len(internalTasks))
 
     queryCoord.Stop()
     err = removeAllSession()
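Taken together, the new Cluster methods make allocation drivable without touching the policy internals. A hedged sketch of exercising them directly (allocateOneSegment is an invented name; it reuses genCollectionSchema and the default* constants from the test helpers in this patch, and assumes a cluster constructed as in the tests above):

func allocateOneSegment(ctx context.Context, cluster Cluster) (int64, error) {
    req := &querypb.LoadSegmentsRequest{
        CollectionID: defaultCollectionID,
        Schema:       genCollectionSchema(defaultCollectionID, false), // the V2 policy sizes requests from the schema
        Infos: []*querypb.SegmentLoadInfo{{
            SegmentID:    defaultSegmentID,
            PartitionID:  defaultPartitionID,
            CollectionID: defaultCollectionID,
            NumOfRows:    defaultNumRowPerSegment,
        }},
    }
    // wait=false fails fast when no query node can take the segment; nil excludes no nodes
    if err := cluster.allocateSegmentsToQueryNode(ctx, []*querypb.LoadSegmentsRequest{req}, false, nil); err != nil {
        return 0, err
    }
    return req.DstNodeID, nil // the policy writes the chosen node back into the request
}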