Fix proxy send query before queryNode watchQueryChannel and search timeout (#14640)

Signed-off-by: xige-16 <xi.ge@zilliz.com>
This commit is contained in:
xige-16 2021-12-30 22:51:21 +08:00 committed by GitHub
parent 4713231671
commit 95f0e9a4a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 173 additions and 3 deletions

View File

@ -296,7 +296,27 @@ func (lct *loadCollectionTask) updateTaskProcess() {
for _, t := range childTasks {
if t.getState() != taskDone {
allDone = false
break
}
// wait watchDeltaChannel and watchQueryChannel task done after loading segment
nodeID := getDstNodeIDByTask(t)
if t.msgType() == commonpb.MsgType_LoadSegments {
if !lct.cluster.hasWatchedDeltaChannel(lct.ctx, nodeID, collectionID) ||
!lct.cluster.hasWatchedQueryChannel(lct.ctx, nodeID, collectionID) {
allDone = false
break
}
}
// wait watchQueryChannel task done after watch dmChannel
if t.msgType() == commonpb.MsgType_WatchDmChannels {
if !lct.cluster.hasWatchedQueryChannel(lct.ctx, nodeID, collectionID) {
allDone = false
break
}
}
}
if allDone {
err := lct.meta.setLoadPercentage(collectionID, 0, 100, querypb.LoadType_loadCollection)
@ -672,6 +692,24 @@ func (lpt *loadPartitionTask) updateTaskProcess() {
if t.getState() != taskDone {
allDone = false
}
// wait watchDeltaChannel and watchQueryChannel task done after loading segment
nodeID := getDstNodeIDByTask(t)
if t.msgType() == commonpb.MsgType_LoadSegments {
if !lpt.cluster.hasWatchedDeltaChannel(lpt.ctx, nodeID, collectionID) ||
!lpt.cluster.hasWatchedQueryChannel(lpt.ctx, nodeID, collectionID) {
allDone = false
break
}
}
// wait watchQueryChannel task done after watching dmChannel
if t.msgType() == commonpb.MsgType_WatchDmChannels {
if !lpt.cluster.hasWatchedQueryChannel(lpt.ctx, nodeID, collectionID) {
allDone = false
break
}
}
}
if allDone {
for _, id := range partitionIDs {
@ -1498,8 +1536,6 @@ func (ht *handoffTask) execute(ctx context.Context) error {
CollectionID: collectionID,
BinlogPaths: segmentBinlogs.FieldBinlogs,
NumOfRows: segmentBinlogs.NumOfRows,
Statslogs: segmentBinlogs.Statslogs,
Deltalogs: segmentBinlogs.Deltalogs,
CompactionFrom: segmentInfo.CompactionFrom,
IndexInfos: segmentInfo.IndexInfos,
}

View File

@ -239,6 +239,43 @@ func genLoadSegmentTask(ctx context.Context, queryCoord *QueryCoord, nodeID int6
return loadSegmentTask
}
func genWatchQueryChannelTask(ctx context.Context, queryCoord *QueryCoord, nodeID int64) *watchQueryChannelTask {
queryChannelInfo := queryCoord.meta.getQueryChannelInfoByID(defaultCollectionID)
req := &querypb.AddQueryChannelRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchQueryChannels,
},
NodeID: nodeID,
CollectionID: defaultCollectionID,
QueryChannel: queryChannelInfo.QueryChannel,
QueryResultChannel: queryChannelInfo.QueryResultChannel,
}
baseTask := newBaseTask(ctx, querypb.TriggerCondition_GrpcRequest)
baseTask.taskID = 200
return &watchQueryChannelTask{
baseTask: baseTask,
AddQueryChannelRequest: req,
cluster: queryCoord.cluster,
}
}
func genWatchDeltaChannelTask(ctx context.Context, queryCoord *QueryCoord, nodeID int64) *watchDeltaChannelTask {
req := &querypb.WatchDeltaChannelsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchDeltaChannels,
},
NodeID: nodeID,
CollectionID: defaultCollectionID,
}
baseTask := newBaseTask(ctx, querypb.TriggerCondition_GrpcRequest)
baseTask.taskID = 300
return &watchDeltaChannelTask{
baseTask: baseTask,
WatchDeltaChannelsRequest: req,
cluster: queryCoord.cluster,
}
}
func waitTaskFinalState(t task, state taskState) {
for {
if t.getState() == state {
@ -1199,3 +1236,66 @@ func TestMergeWatchDeltaChannelInfo(t *testing.T) {
}
assert.ElementsMatch(t, expected, results)
}
func TestUpdateTaskProcessWhenLoadSegment(t *testing.T) {
refreshParams()
ctx := context.Background()
queryCoord, err := startQueryCoord(ctx)
assert.Nil(t, err)
node1, err := startQueryNodeServer(ctx)
assert.Nil(t, err)
waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID)
queryCoord.meta.addCollection(defaultCollectionID, querypb.LoadType_loadCollection, genCollectionSchema(defaultCollectionID, false))
loadSegmentTask := genLoadSegmentTask(ctx, queryCoord, node1.queryNodeID)
loadCollectionTask := loadSegmentTask.getParentTask()
queryCoord.scheduler.processTask(loadSegmentTask)
collectionInfo, err := queryCoord.meta.getCollectionInfoByID(defaultCollectionID)
assert.Nil(t, err)
assert.Equal(t, int64(0), collectionInfo.InMemoryPercentage)
watchQueryChannel := genWatchQueryChannelTask(ctx, queryCoord, node1.queryNodeID)
watchQueryChannel.setParentTask(loadCollectionTask)
watchDeltaChannel := genWatchDeltaChannelTask(ctx, queryCoord, node1.queryNodeID)
watchDeltaChannel.setParentTask(loadCollectionTask)
queryCoord.scheduler.processTask(watchQueryChannel)
queryCoord.scheduler.processTask(watchDeltaChannel)
collectionInfo, err = queryCoord.meta.getCollectionInfoByID(defaultCollectionID)
assert.Nil(t, err)
assert.Equal(t, int64(100), collectionInfo.InMemoryPercentage)
err = removeAllSession()
assert.Nil(t, err)
}
func TestUpdateTaskProcessWhenWatchDmChannel(t *testing.T) {
refreshParams()
ctx := context.Background()
queryCoord, err := startQueryCoord(ctx)
assert.Nil(t, err)
node1, err := startQueryNodeServer(ctx)
assert.Nil(t, err)
waitQueryNodeOnline(queryCoord.cluster, node1.queryNodeID)
queryCoord.meta.addCollection(defaultCollectionID, querypb.LoadType_loadCollection, genCollectionSchema(defaultCollectionID, false))
watchDmChannel := genWatchDmChannelTask(ctx, queryCoord, node1.queryNodeID)
loadCollectionTask := watchDmChannel.getParentTask()
queryCoord.scheduler.processTask(watchDmChannel)
collectionInfo, err := queryCoord.meta.getCollectionInfoByID(defaultCollectionID)
assert.Nil(t, err)
assert.Equal(t, int64(0), collectionInfo.InMemoryPercentage)
watchQueryChannel := genWatchQueryChannelTask(ctx, queryCoord, node1.queryNodeID)
watchQueryChannel.setParentTask(loadCollectionTask)
queryCoord.scheduler.processTask(watchQueryChannel)
collectionInfo, err = queryCoord.meta.getCollectionInfoByID(defaultCollectionID)
assert.Nil(t, err)
assert.Equal(t, int64(100), collectionInfo.InMemoryPercentage)
err = removeAllSession()
assert.Nil(t, err)
}

View File

@ -16,7 +16,10 @@
package querycoord
import "github.com/milvus-io/milvus/internal/proto/schemapb"
import (
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
)
func getCompareMapFromSlice(sliceData []int64) map[int64]struct{} {
compareMap := make(map[int64]struct{})
@ -37,3 +40,34 @@ func getVecFieldIDs(schema *schemapb.CollectionSchema) []int64 {
return vecFieldIDs
}
func getDstNodeIDByTask(t task) int64 {
var nodeID int64
switch t.msgType() {
case commonpb.MsgType_LoadSegments:
loadSegment := t.(*loadSegmentTask)
nodeID = loadSegment.DstNodeID
case commonpb.MsgType_WatchDmChannels:
watchDmChannel := t.(*watchDmChannelTask)
nodeID = watchDmChannel.NodeID
case commonpb.MsgType_WatchDeltaChannels:
watchDeltaChannel := t.(*watchDeltaChannelTask)
nodeID = watchDeltaChannel.NodeID
case commonpb.MsgType_WatchQueryChannels:
watchQueryChannel := t.(*watchQueryChannelTask)
nodeID = watchQueryChannel.NodeID
case commonpb.MsgType_ReleaseCollection:
releaseCollection := t.(*releaseCollectionTask)
nodeID = releaseCollection.NodeID
case commonpb.MsgType_ReleasePartitions:
releasePartition := t.(*releasePartitionTask)
nodeID = releasePartition.NodeID
case commonpb.MsgType_ReleaseSegments:
releaseSegment := t.(*releaseSegmentTask)
nodeID = releaseSegment.NodeID
default:
//TODO::
}
return nodeID
}