From 1a336bbbf1e1782d9c548c3efd31fe2e0278559b Mon Sep 17 00:00:00 2001 From: wayblink Date: Fri, 24 Jun 2022 10:54:15 +0800 Subject: [PATCH] Fix: WatchDmChannelsRequest can be too large to save in etcd (#17722) Signed-off-by: wayblink --- internal/querycoord/task.go | 63 ++++++---------- internal/querycoord/task_scheduler.go | 15 +++- internal/querycoord/task_scheduler_test.go | 13 ++++ internal/querycoord/task_test.go | 22 ------ internal/querycoord/task_util.go | 59 +++++++++++++++ internal/querycoord/task_util_test.go | 85 ++++++++++++++++++++++ 6 files changed, 190 insertions(+), 67 deletions(-) create mode 100644 internal/querycoord/task_util.go create mode 100644 internal/querycoord/task_util_test.go diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go index 1b090de54c..2597ef0ad6 100644 --- a/internal/querycoord/task.go +++ b/internal/querycoord/task.go @@ -510,18 +510,12 @@ func (lct *loadCollectionTask) execute(ctx context.Context) error { for _, info := range mergedDmChannel { msgBase := proto.Clone(lct.Base).(*commonpb.MsgBase) msgBase.MsgType = commonpb.MsgType_WatchDmChannels - segmentInfos, err := generateVChannelSegmentInfos(lct.meta, info) - if err != nil { - lct.setResultInfo(err) - return err - } watchRequest := &querypb.WatchDmChannelsRequest{ Base: msgBase, CollectionID: collectionID, //PartitionIDs: toLoadPartitionIDs, - Infos: []*datapb.VchannelInfo{info}, - Schema: lct.Schema, - SegmentInfos: segmentInfos, + Infos: []*datapb.VchannelInfo{info}, + Schema: lct.Schema, LoadMeta: &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, CollectionID: collectionID, @@ -530,7 +524,12 @@ func (lct *loadCollectionTask) execute(ctx context.Context) error { ReplicaID: replica.GetReplicaID(), } - watchDmChannelReqs = append(watchDmChannelReqs, watchRequest) + fullWatchRequest, err := generateFullWatchDmChannelsRequest(lct.meta, watchRequest) + if err != nil { + lct.setResultInfo(err) + return err + } + watchDmChannelReqs = 
append(watchDmChannelReqs, fullWatchRequest) } internalTasks, err := assignInternalTask(ctx, lct, lct.meta, lct.cluster, loadSegmentReqs, watchDmChannelReqs, false, nil, replica.GetNodeIds(), -1, lct.broker) @@ -946,18 +945,12 @@ func (lpt *loadPartitionTask) execute(ctx context.Context) error { for _, info := range mergedDmChannel { msgBase := proto.Clone(lpt.Base).(*commonpb.MsgBase) msgBase.MsgType = commonpb.MsgType_WatchDmChannels - segmentInfos, err := generateVChannelSegmentInfos(lpt.meta, info) - if err != nil { - lpt.setResultInfo(err) - return err - } watchRequest := &querypb.WatchDmChannelsRequest{ Base: msgBase, CollectionID: collectionID, PartitionIDs: partitionIDs, Infos: []*datapb.VchannelInfo{info}, Schema: lpt.Schema, - SegmentInfos: segmentInfos, LoadMeta: &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadPartition, CollectionID: collectionID, @@ -966,7 +959,12 @@ func (lpt *loadPartitionTask) execute(ctx context.Context) error { ReplicaID: replica.GetReplicaID(), } - watchDmChannelReqs = append(watchDmChannelReqs, watchRequest) + fullWatchRequest, err := generateFullWatchDmChannelsRequest(lpt.meta, watchRequest) + if err != nil { + lpt.setResultInfo(err) + return err + } + watchDmChannelReqs = append(watchDmChannelReqs, fullWatchRequest) } internalTasks, err := assignInternalTask(ctx, lpt, lpt.meta, lpt.cluster, loadSegmentReqs, watchDmChannelReqs, false, nil, replica.GetNodeIds(), -1, lpt.broker) @@ -1404,7 +1402,7 @@ func (wdt *watchDmChannelTask) msgBase() *commonpb.MsgBase { } func (wdt *watchDmChannelTask) marshal() ([]byte, error) { - return proto.Marshal(wdt.WatchDmChannelsRequest) + return proto.Marshal(thinWatchDmChannelsRequest(wdt.WatchDmChannelsRequest)) } func (wdt *watchDmChannelTask) isValid() bool { @@ -2020,17 +2018,11 @@ func (lbt *loadBalanceTask) processNodeDownLoadBalance(ctx context.Context) erro if _, ok := dmChannels[channelName]; ok { msgBase := proto.Clone(lbt.Base).(*commonpb.MsgBase) msgBase.MsgType = 
commonpb.MsgType_WatchDmChannels - channelSegmentInfos, err := generateVChannelSegmentInfos(lbt.meta, vChannelInfo) - if err != nil { - lbt.setResultInfo(err) - return err - } watchRequest := &querypb.WatchDmChannelsRequest{ Base: msgBase, CollectionID: collectionID, Infos: []*datapb.VchannelInfo{vChannelInfo}, Schema: schema, - SegmentInfos: channelSegmentInfos, LoadMeta: &querypb.LoadMetaInfo{ LoadType: collectionInfo.LoadType, CollectionID: collectionID, @@ -2043,7 +2035,13 @@ func (lbt *loadBalanceTask) processNodeDownLoadBalance(ctx context.Context) erro watchRequest.PartitionIDs = toRecoverPartitionIDs } - watchDmChannelReqs = append(watchDmChannelReqs, watchRequest) + fullWatchRequest, err := generateFullWatchDmChannelsRequest(lbt.meta, watchRequest) + if err != nil { + lbt.setResultInfo(err) + return err + } + + watchDmChannelReqs = append(watchDmChannelReqs, fullWatchRequest) } } @@ -2576,20 +2574,3 @@ func mergeDmChannelInfo(infos []*datapb.VchannelInfo) map[string]*datapb.Vchanne return minPositions } - -// generateVChannelSegmentInfos returns a map contains -// all the segment infos(flushed, unflushed and dropped) in a vChannel. -func generateVChannelSegmentInfos(m Meta, vChannel *datapb.VchannelInfo) (map[int64]*datapb.SegmentInfo, error) { - segmentIds := make([]int64, 0) - segmentIds = append(append(append(segmentIds, vChannel.FlushedSegmentIds...), vChannel.UnflushedSegmentIds...), vChannel.DroppedSegmentIds...) 
- segmentInfos, err := m.getDataSegmentInfosByIDs(segmentIds) - if err != nil { - log.Error("Get Vchannel SegmentInfos failed", zap.String("vChannel", vChannel.String()), zap.Error(err)) - return nil, err - } - segmentDict := make(map[int64]*datapb.SegmentInfo) - for _, info := range segmentInfos { - segmentDict[info.ID] = info - } - return segmentDict, nil -} diff --git a/internal/querycoord/task_scheduler.go b/internal/querycoord/task_scheduler.go index b91c00f418..1663fbe054 100644 --- a/internal/querycoord/task_scheduler.go +++ b/internal/querycoord/task_scheduler.go @@ -354,14 +354,18 @@ func (scheduler *TaskScheduler) unmarshalTask(taskID UniqueID, t string) (task, newTask = releaseSegmentTask case commonpb.MsgType_WatchDmChannels: //TODO::trigger condition may be different - loadReq := querypb.WatchDmChannelsRequest{} - err = proto.Unmarshal([]byte(t), &loadReq) + req := querypb.WatchDmChannelsRequest{} + err = proto.Unmarshal([]byte(t), &req) + if err != nil { + return nil, err + } + fullReq, err := generateFullWatchDmChannelsRequest(scheduler.meta, &req) if err != nil { return nil, err } watchDmChannelTask := &watchDmChannelTask{ baseTask: baseTask, - WatchDmChannelsRequest: &loadReq, + WatchDmChannelsRequest: fullReq, cluster: scheduler.cluster, meta: scheduler.meta, excludeNodeIDs: []int64{}, @@ -484,7 +488,10 @@ func (scheduler *TaskScheduler) processTask(t task) error { default: //TODO:: } - log.Debug("updateKVFn: the size of internal request", zap.Int("size", protoSize), zap.Int64("taskID", childTask.getTaskID())) + log.Debug("updateKVFn: the size of internal request", + zap.Int("size", protoSize), + zap.Int64("taskID", childTask.getTaskID()), + zap.String("type", childTask.msgType().String())) blobs, err := childTask.marshal() if err != nil { return err diff --git a/internal/querycoord/task_scheduler_test.go b/internal/querycoord/task_scheduler_test.go index 97ba5f5e18..049b3d708c 100644 --- a/internal/querycoord/task_scheduler_test.go +++ 
b/internal/querycoord/task_scheduler_test.go @@ -199,9 +199,14 @@ func TestUnMarshalTask(t *testing.T) { defer etcdCli.Close() kv := etcdkv.NewEtcdKV(etcdCli, Params.EtcdCfg.MetaRootPath) baseCtx, cancel := context.WithCancel(context.Background()) + dataCoord := &dataCoordMock{} + meta := &MetaReplica{ + dataCoord: dataCoord, + } taskScheduler := &TaskScheduler{ ctx: baseCtx, cancel: cancel, + meta: meta, } t.Run("Test loadCollectionTask", func(t *testing.T) { @@ -350,6 +355,14 @@ func TestUnMarshalTask(t *testing.T) { task, err := taskScheduler.unmarshalTask(1006, value) assert.Nil(t, err) assert.Equal(t, task.msgType(), commonpb.MsgType_WatchDmChannels) + + dataCoord.returnError = true + defer func() { + dataCoord.returnError = false + }() + task2, err := taskScheduler.unmarshalTask(1006, value) + assert.Error(t, err) + assert.Nil(t, task2) }) t.Run("Test watchDeltaChannelTask", func(t *testing.T) { diff --git a/internal/querycoord/task_test.go b/internal/querycoord/task_test.go index 7215ec2481..ed26fd7f01 100644 --- a/internal/querycoord/task_test.go +++ b/internal/querycoord/task_test.go @@ -1456,25 +1456,3 @@ func Test_LoadSegment(t *testing.T) { err = removeAllSession() assert.Nil(t, err) } - -func TestGenerateVChannelSegmentInfos(t *testing.T) { - dataCoord := &dataCoordMock{} - meta := &MetaReplica{ - dataCoord: dataCoord, - } - - deltaChannel := &datapb.VchannelInfo{ - CollectionID: defaultCollectionID, - ChannelName: "delta-channel1", - UnflushedSegmentIds: []int64{1}, - } - - segmentDict, err := generateVChannelSegmentInfos(meta, deltaChannel) - assert.Nil(t, err) - assert.Equal(t, 1, len(segmentDict)) - - dataCoord.returnError = true - segmentDict2, err := generateVChannelSegmentInfos(meta, deltaChannel) - assert.Error(t, err) - assert.Empty(t, segmentDict2) -} diff --git a/internal/querycoord/task_util.go b/internal/querycoord/task_util.go new file mode 100644 index 0000000000..0ca67f53be --- /dev/null +++ b/internal/querycoord/task_util.go @@ -0,0 
+1,59 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package querycoord
+
+import (
+	"github.com/golang/protobuf/proto"
+	"github.com/milvus-io/milvus/internal/log"
+	"github.com/milvus-io/milvus/internal/proto/datapb"
+	"github.com/milvus-io/milvus/internal/proto/querypb"
+	"go.uber.org/zap"
+)
+
+// generateFullWatchDmChannelsRequest returns a clone of request whose SegmentInfos is filled from Meta; request itself is not modified, and a nil request is returned on error
+func generateFullWatchDmChannelsRequest(m Meta, request *querypb.WatchDmChannelsRequest) (*querypb.WatchDmChannelsRequest, error) {
+	cloned := proto.Clone(request).(*querypb.WatchDmChannelsRequest)
+	vChannels := cloned.GetInfos()
+
+	// collect every segment ID (flushed, unflushed and dropped) referenced by the vChannels
+	segmentIds := make([]int64, 0)
+	for _, vChannel := range vChannels {
+		segmentIds = append(segmentIds, vChannel.FlushedSegmentIds...)
+		segmentIds = append(segmentIds, vChannel.UnflushedSegmentIds...)
+		segmentIds = append(segmentIds, vChannel.DroppedSegmentIds...)
+	}
+	segmentInfos, err := m.getDataSegmentInfosByIDs(segmentIds)
+	if err != nil {
+		log.Error("Get Vchannel SegmentInfos failed", zap.Int64("collectionID", cloned.CollectionID), zap.Error(err))
+		return nil, err
+	}
+	segmentDict := make(map[int64]*datapb.SegmentInfo, len(segmentInfos))
+	for _, info := range segmentInfos {
+		segmentDict[info.ID] = info
+	}
+	cloned.SegmentInfos = segmentDict
+
+	return cloned, nil
+}
+
+// thinWatchDmChannelsRequest returns a clone of request with SegmentInfos emptied; request itself is not modified.
+// The thin version is used for storage because the complete version may be too large.
+func thinWatchDmChannelsRequest(request *querypb.WatchDmChannelsRequest) *querypb.WatchDmChannelsRequest {
+	cloned := proto.Clone(request).(*querypb.WatchDmChannelsRequest)
+	cloned.SegmentInfos = make(map[int64]*datapb.SegmentInfo)
+	return cloned
+}
diff --git a/internal/querycoord/task_util_test.go b/internal/querycoord/task_util_test.go
new file mode 100644
index 0000000000..cf0de21a52
--- /dev/null
+++ b/internal/querycoord/task_util_test.go
@@ -0,0 +1,85 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package querycoord
+
+import (
+	"testing"
+
+	"github.com/milvus-io/milvus/internal/proto/commonpb"
+	"github.com/milvus-io/milvus/internal/proto/datapb"
+	"github.com/milvus-io/milvus/internal/proto/querypb"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestGenerateFullWatchDmChannelsRequest(t *testing.T) {
+	dataCoord := &dataCoordMock{}
+	meta := &MetaReplica{
+		dataCoord: dataCoord,
+	}
+
+	deltaChannel := &datapb.VchannelInfo{
+		CollectionID:        defaultCollectionID,
+		ChannelName:         "delta-channel1",
+		UnflushedSegmentIds: []int64{1},
+	}
+
+	watchDmChannelsRequest := &querypb.WatchDmChannelsRequest{
+		Base: &commonpb.MsgBase{
+			MsgType: commonpb.MsgType_WatchDmChannels,
+		},
+		Infos:  []*datapb.VchannelInfo{deltaChannel},
+		NodeID: 1,
+	}
+
+	fullWatchDmChannelsRequest, err := generateFullWatchDmChannelsRequest(meta, watchDmChannelsRequest)
+	assert.Nil(t, err)
+	assert.NotEmpty(t, fullWatchDmChannelsRequest.GetSegmentInfos())
+
+	dataCoord.returnError = true
+	fullWatchDmChannelsRequest2, err := generateFullWatchDmChannelsRequest(meta, watchDmChannelsRequest)
+	assert.Error(t, err)
+	assert.Nil(t, fullWatchDmChannelsRequest2) // no request is returned on failure
+}
+
+func TestThinWatchDmChannelsRequest(t *testing.T) {
+	var segmentID int64 = 1
+
+	deltaChannel := &datapb.VchannelInfo{
+		CollectionID:        defaultCollectionID,
+		ChannelName:         "delta-channel1",
+		UnflushedSegmentIds: []int64{segmentID},
+	}
+
+	segment := &datapb.SegmentInfo{
+		ID: segmentID,
+	}
+
+	segmentInfos := map[int64]*datapb.SegmentInfo{segmentID: segment}
+
+	watchDmChannelsRequest := &querypb.WatchDmChannelsRequest{
+		Base: &commonpb.MsgBase{
+			MsgType: commonpb.MsgType_WatchDmChannels,
+		},
+		Infos:        []*datapb.VchannelInfo{deltaChannel},
+		NodeID:       1,
+		SegmentInfos: segmentInfos,
+	}
+
+	thinReq := thinWatchDmChannelsRequest(watchDmChannelsRequest)
+	assert.Empty(t, thinReq.GetSegmentInfos())
+	assert.NotEmpty(t, watchDmChannelsRequest.GetSegmentInfos()) // thinning works on a clone; the input keeps its segment infos
+}