From 95f0e9a4a6cfaec50bbd1deea356132b03e7d8b0 Mon Sep 17 00:00:00 2001 From: xige-16 Date: Thu, 30 Dec 2021 22:51:21 +0800 Subject: [PATCH] Fix proxy send query before queryNode watchQueryChannel and search timeout (#14640) Signed-off-by: xige-16 --- internal/querycoord/task.go | 40 ++++++++++++- internal/querycoord/task_test.go | 100 +++++++++++++++++++++++++++++++ internal/querycoord/util.go | 36 ++++++++++- 3 files changed, 173 insertions(+), 3 deletions(-) diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go index bad1dd38ab..f3ee905dec 100644 --- a/internal/querycoord/task.go +++ b/internal/querycoord/task.go @@ -296,7 +296,27 @@ func (lct *loadCollectionTask) updateTaskProcess() { for _, t := range childTasks { if t.getState() != taskDone { allDone = false + break } + + // wait watchDeltaChannel and watchQueryChannel task done after loading segment + nodeID := getDstNodeIDByTask(t) + if t.msgType() == commonpb.MsgType_LoadSegments { + if !lct.cluster.hasWatchedDeltaChannel(lct.ctx, nodeID, collectionID) || + !lct.cluster.hasWatchedQueryChannel(lct.ctx, nodeID, collectionID) { + allDone = false + break + } + } + + // wait watchQueryChannel task done after watch dmChannel + if t.msgType() == commonpb.MsgType_WatchDmChannels { + if !lct.cluster.hasWatchedQueryChannel(lct.ctx, nodeID, collectionID) { + allDone = false + break + } + } + } if allDone { err := lct.meta.setLoadPercentage(collectionID, 0, 100, querypb.LoadType_loadCollection) @@ -672,6 +692,24 @@ func (lpt *loadPartitionTask) updateTaskProcess() { if t.getState() != taskDone { allDone = false } + + // wait watchDeltaChannel and watchQueryChannel task done after loading segment + nodeID := getDstNodeIDByTask(t) + if t.msgType() == commonpb.MsgType_LoadSegments { + if !lpt.cluster.hasWatchedDeltaChannel(lpt.ctx, nodeID, collectionID) || + !lpt.cluster.hasWatchedQueryChannel(lpt.ctx, nodeID, collectionID) { + allDone = false + break + } + } + + // wait watchQueryChannel task done after watching dmChannel + if t.msgType() == commonpb.MsgType_WatchDmChannels { + if !lpt.cluster.hasWatchedQueryChannel(lpt.ctx, nodeID, collectionID) { + allDone = false + break + } + } } if allDone { for _, id := range partitionIDs { @@ -1498,8 +1536,6 @@ func (ht *handoffTask) execute(ctx context.Context) error { CollectionID: collectionID, BinlogPaths: segmentBinlogs.FieldBinlogs, NumOfRows: segmentBinlogs.NumOfRows, - Statslogs: segmentBinlogs.Statslogs, - Deltalogs: segmentBinlogs.Deltalogs, CompactionFrom: segmentInfo.CompactionFrom, IndexInfos: segmentInfo.IndexInfos, } diff --git a/internal/querycoord/task_test.go b/internal/querycoord/task_test.go index 2a7ba6851d..fb787db4c7 100644 --- a/internal/querycoord/task_test.go +++ b/internal/querycoord/task_test.go @@ -239,6 +239,43 @@ func genLoadSegmentTask(ctx context.Context, queryCoord *QueryCoord, nodeID int6 return loadSegmentTask } +func genWatchQueryChannelTask(ctx context.Context, queryCoord *QueryCoord, nodeID int64) *watchQueryChannelTask { + queryChannelInfo := queryCoord.meta.getQueryChannelInfoByID(defaultCollectionID) + req := &querypb.AddQueryChannelRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_WatchQueryChannels, + }, + NodeID: nodeID, + CollectionID: defaultCollectionID, + QueryChannel: queryChannelInfo.QueryChannel, + QueryResultChannel: queryChannelInfo.QueryResultChannel, + } + baseTask := newBaseTask(ctx, querypb.TriggerCondition_GrpcRequest) + baseTask.taskID = 200 + return &watchQueryChannelTask{ + baseTask: baseTask, + AddQueryChannelRequest: req, + cluster: queryCoord.cluster, + } +} + +func genWatchDeltaChannelTask(ctx context.Context, queryCoord *QueryCoord, nodeID int64) *watchDeltaChannelTask { + req := &querypb.WatchDeltaChannelsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_WatchDeltaChannels, + }, + NodeID: nodeID, + CollectionID: defaultCollectionID, + } + baseTask := newBaseTask(ctx, querypb.TriggerCondition_GrpcRequest) + baseTask.taskID = 300 + return &watchDeltaChannelTask{ + baseTask: baseTask, + WatchDeltaChannelsRequest: req, + cluster: queryCoord.cluster, + } +} + func waitTaskFinalState(t task, state taskState) { for { if t.getState() == state { @@ -1199,3 +1236,66 @@ func TestMergeWatchDeltaChannelInfo(t *testing.T) { } assert.ElementsMatch(t, expected, results) } + +func TestUpdateTaskProcessWhenLoadSegment(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.Nil(t, err) + + node1, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID) + queryCoord.meta.addCollection(defaultCollectionID, querypb.LoadType_loadCollection, genCollectionSchema(defaultCollectionID, false)) + + loadSegmentTask := genLoadSegmentTask(ctx, queryCoord, node1.queryNodeID) + loadCollectionTask := loadSegmentTask.getParentTask() + + queryCoord.scheduler.processTask(loadSegmentTask) + collectionInfo, err := queryCoord.meta.getCollectionInfoByID(defaultCollectionID) + assert.Nil(t, err) + assert.Equal(t, int64(0), collectionInfo.InMemoryPercentage) + + watchQueryChannel := genWatchQueryChannelTask(ctx, queryCoord, node1.queryNodeID) + watchQueryChannel.setParentTask(loadCollectionTask) + watchDeltaChannel := genWatchDeltaChannelTask(ctx, queryCoord, node1.queryNodeID) + watchDeltaChannel.setParentTask(loadCollectionTask) + queryCoord.scheduler.processTask(watchQueryChannel) + queryCoord.scheduler.processTask(watchDeltaChannel) + collectionInfo, err = queryCoord.meta.getCollectionInfoByID(defaultCollectionID) + assert.Nil(t, err) + assert.Equal(t, int64(100), collectionInfo.InMemoryPercentage) + + err = removeAllSession() + assert.Nil(t, err) +} + +func TestUpdateTaskProcessWhenWatchDmChannel(t *testing.T) { + refreshParams() + ctx := context.Background() + queryCoord, err := startQueryCoord(ctx) + assert.Nil(t, err) + + node1, err := startQueryNodeServer(ctx) + assert.Nil(t, err) + waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID) + queryCoord.meta.addCollection(defaultCollectionID, querypb.LoadType_loadCollection, genCollectionSchema(defaultCollectionID, false)) + + watchDmChannel := genWatchDmChannelTask(ctx, queryCoord, node1.queryNodeID) + loadCollectionTask := watchDmChannel.getParentTask() + + queryCoord.scheduler.processTask(watchDmChannel) + collectionInfo, err := queryCoord.meta.getCollectionInfoByID(defaultCollectionID) + assert.Nil(t, err) + assert.Equal(t, int64(0), collectionInfo.InMemoryPercentage) + + watchQueryChannel := genWatchQueryChannelTask(ctx, queryCoord, node1.queryNodeID) + watchQueryChannel.setParentTask(loadCollectionTask) + queryCoord.scheduler.processTask(watchQueryChannel) + collectionInfo, err = queryCoord.meta.getCollectionInfoByID(defaultCollectionID) + assert.Nil(t, err) + assert.Equal(t, int64(100), collectionInfo.InMemoryPercentage) + + err = removeAllSession() + assert.Nil(t, err) +} diff --git a/internal/querycoord/util.go b/internal/querycoord/util.go index a8d7a234d9..c8a7594509 100644 --- a/internal/querycoord/util.go +++ b/internal/querycoord/util.go @@ -16,7 +16,10 @@ package querycoord -import "github.com/milvus-io/milvus/internal/proto/schemapb" +import ( + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/schemapb" +) func getCompareMapFromSlice(sliceData []int64) map[int64]struct{} { compareMap := make(map[int64]struct{}) @@ -37,3 +40,34 @@ func getVecFieldIDs(schema *schemapb.CollectionSchema) []int64 { return vecFieldIDs } + +func getDstNodeIDByTask(t task) int64 { + var nodeID int64 + switch t.msgType() { + case commonpb.MsgType_LoadSegments: + loadSegment := t.(*loadSegmentTask) + nodeID = loadSegment.DstNodeID + case commonpb.MsgType_WatchDmChannels: + watchDmChannel := t.(*watchDmChannelTask) + nodeID = watchDmChannel.NodeID + case commonpb.MsgType_WatchDeltaChannels: + watchDeltaChannel := t.(*watchDeltaChannelTask) + nodeID = watchDeltaChannel.NodeID + case commonpb.MsgType_WatchQueryChannels: + watchQueryChannel := t.(*watchQueryChannelTask) + nodeID = watchQueryChannel.NodeID + case commonpb.MsgType_ReleaseCollection: + releaseCollection := t.(*releaseCollectionTask) + nodeID = releaseCollection.NodeID + case commonpb.MsgType_ReleasePartitions: + releasePartition := t.(*releasePartitionTask) + nodeID = releasePartition.NodeID + case commonpb.MsgType_ReleaseSegments: + releaseSegment := t.(*releaseSegmentTask) + nodeID = releaseSegment.NodeID + default: + //TODO:: + } + + return nodeID +}