diff --git a/internal/querynode/impl.go b/internal/querynode/impl.go
new file mode 100644
index 0000000000..3529dd9bbb
--- /dev/null
+++ b/internal/querynode/impl.go
@@ -0,0 +1,388 @@
+package querynode
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"strconv"
+	"strings"
+
+	"go.uber.org/zap"
+
+	"github.com/zilliztech/milvus-distributed/internal/log"
+	"github.com/zilliztech/milvus-distributed/internal/msgstream"
+	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
+	"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
+	"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
+	queryPb "github.com/zilliztech/milvus-distributed/internal/proto/querypb"
+	"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
+)
+
+func (node *QueryNode) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
+	stats := &internalpb.ComponentStates{
+		Status: &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_Success,
+		},
+	}
+	code, ok := node.stateCode.Load().(internalpb.StateCode)
+	if !ok {
+		errMsg := "unexpected error in type assertion"
+		stats.Status = &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_UnexpectedError,
+			Reason:    errMsg,
+		}
+		return stats, errors.New(errMsg)
+	}
+	info := &internalpb.ComponentInfo{
+		NodeID:    Params.QueryNodeID,
+		Role:      typeutil.QueryNodeRole,
+		StateCode: code,
+	}
+	stats.State = info
+	return stats, nil
+}
+
+func (node *QueryNode) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
+	return &milvuspb.StringResponse{
+		Status: &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_Success,
+			Reason:    "",
+		},
+		Value: Params.QueryTimeTickChannelName,
+	}, nil
+}
+
+func (node *QueryNode) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
+	return &milvuspb.StringResponse{
+		Status: &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_Success,
+			Reason:    "",
+		},
+		Value: Params.StatsChannelName,
+	}, nil
+}
+
+func (node *QueryNode) AddQueryChannel(ctx context.Context, in *queryPb.AddQueryChannelRequest) (*commonpb.Status, error) {
+	if node.searchService == nil || node.searchService.searchMsgStream == nil {
+		errMsg := "null search service or null search message stream"
+		status := &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_UnexpectedError,
+			Reason:    errMsg,
+		}
+
+		return status, errors.New(errMsg)
+	}
+
+	// add request channel
+	consumeChannels := []string{in.RequestChannelID}
+	consumeSubName := Params.MsgChannelSubName
+	node.searchService.searchMsgStream.AsConsumer(consumeChannels, consumeSubName)
+	log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)
+
+	// add result channel
+	producerChannels := []string{in.ResultChannelID}
+	node.searchService.searchResultMsgStream.AsProducer(producerChannels)
+	log.Debug("querynode AsProducer: " + strings.Join(producerChannels, ", "))
+
+	status := &commonpb.Status{
+		ErrorCode: commonpb.ErrorCode_Success,
+	}
+	return status, nil
+}
+
+func (node *QueryNode) RemoveQueryChannel(ctx context.Context, in *queryPb.RemoveQueryChannelRequest) (*commonpb.Status, error) {
+	// if node.searchService == nil || node.searchService.searchMsgStream == nil {
+	// 	errMsg := "null search service or null search result message stream"
+	// 	status := &commonpb.Status{
+	// 		ErrorCode: commonpb.ErrorCode_UnexpectedError,
+	// 		Reason:    errMsg,
+	// 	}
+
+	// 	return status, errors.New(errMsg)
+	// }
+
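+	// NOTE: the removal logic below stays commented out because msgstream has no
+	// API yet for removing consumers or producers (see the inline TODOs); until
+	// it does, RemoveQueryChannel simply reports success.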
+	// searchStream, ok := node.searchService.searchMsgStream.(*pulsarms.PulsarMsgStream)
+	// if !ok {
+	// 	errMsg := "type assertion failed for search message stream"
+	// 	status := &commonpb.Status{
+	// 		ErrorCode: commonpb.ErrorCode_UnexpectedError,
+	// 		Reason:    errMsg,
+	// 	}
+
+	// 	return status, errors.New(errMsg)
+	// }
+
+	// resultStream, ok := node.searchService.searchResultMsgStream.(*pulsarms.PulsarMsgStream)
+	// if !ok {
+	// 	errMsg := "type assertion failed for search result message stream"
+	// 	status := &commonpb.Status{
+	// 		ErrorCode: commonpb.ErrorCode_UnexpectedError,
+	// 		Reason:    errMsg,
+	// 	}
+
+	// 	return status, errors.New(errMsg)
+	// }
+
+	// // remove request channel
+	// consumeChannels := []string{in.RequestChannelID}
+	// consumeSubName := Params.MsgChannelSubName
+	// // TODO: searchStream.RemovePulsarConsumers(producerChannels)
+	// searchStream.AsConsumer(consumeChannels, consumeSubName)
+
+	// // remove result channel
+	// producerChannels := []string{in.ResultChannelID}
+	// // TODO: resultStream.RemovePulsarProducer(producerChannels)
+	// resultStream.AsProducer(producerChannels)
+
+	status := &commonpb.Status{
+		ErrorCode: commonpb.ErrorCode_Success,
+	}
+	return status, nil
+}
+
+func (node *QueryNode) WatchDmChannels(ctx context.Context, in *queryPb.WatchDmChannelsRequest) (*commonpb.Status, error) {
+	log.Debug("starting WatchDmChannels ...", zap.String("ChannelIDs", fmt.Sprintln(in.ChannelIDs)))
+	collectionID := in.CollectionID
+	ds, err := node.getDataSyncService(collectionID)
+	if err != nil || ds.dmStream == nil {
+		errMsg := "null data sync service or null data manipulation stream, collectionID = " + fmt.Sprintln(collectionID)
+		status := &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_UnexpectedError,
+			Reason:    errMsg,
+		}
+		log.Error(errMsg)
+		return status, errors.New(errMsg)
+	}
+
+	switch t := ds.dmStream.(type) {
+	case *msgstream.MqTtMsgStream:
+	default:
+		_ = t
+		errMsg := "type assertion failed for dm message stream"
+		status := &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_UnexpectedError,
+			Reason:    errMsg,
+		}
+		log.Error(errMsg)
+		return status, errors.New(errMsg)
+	}
+
+	getUniqueSubName := func() string {
+		prefixName := Params.MsgChannelSubName
+		return prefixName + "-" + strconv.FormatInt(collectionID, 10)
+	}
+
+	// add request channel
+	consumeChannels := in.ChannelIDs
+	toSeekInfo := make([]*internalpb.MsgPosition, 0)
+	toDirSubChannels := make([]string, 0)
+
+	consumeSubName := getUniqueSubName()
+
+	for _, info := range in.Infos {
+		if len(info.Pos.MsgID) == 0 {
+			toDirSubChannels = append(toDirSubChannels, info.ChannelID)
+			continue
+		}
+		info.Pos.MsgGroup = consumeSubName
+		toSeekInfo = append(toSeekInfo, info.Pos)
+
+		log.Debug("prevent inserting segments", zap.String("segmentIDs", fmt.Sprintln(info.ExcludedSegments)))
+		err := node.replica.addExcludedSegments(collectionID, info.ExcludedSegments)
+		if err != nil {
+			status := &commonpb.Status{
+				ErrorCode: commonpb.ErrorCode_UnexpectedError,
+				Reason:    err.Error(),
+			}
+			log.Error(err.Error())
+			return status, err
+		}
+	}
+
+	ds.dmStream.AsConsumer(toDirSubChannels, consumeSubName)
+	for _, pos := range toSeekInfo {
+		err := ds.dmStream.Seek(pos)
+		if err != nil {
+			errMsg := "msgStream seek error: " + err.Error()
+			status := &commonpb.Status{
+				ErrorCode: commonpb.ErrorCode_UnexpectedError,
+				Reason:    errMsg,
+			}
+			log.Error(errMsg)
+			return status, errors.New(errMsg)
+		}
+	}
+	log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)
+
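+	// every requested channel is now attached: channels without a saved position
+	// were subscribed directly, the others were seeked back to their checkpoints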
+	status := &commonpb.Status{
+		ErrorCode: commonpb.ErrorCode_Success,
+	}
+	log.Debug("WatchDmChannels done", zap.String("ChannelIDs", fmt.Sprintln(in.ChannelIDs)))
+	return status, nil
+}
+
+func (node *QueryNode) LoadSegments(ctx context.Context, in *queryPb.LoadSegmentsRequest) (*commonpb.Status, error) {
+	dct := &loadSegmentsTask{
+		baseTask: baseTask{
+			ctx:  ctx,
+			done: make(chan error),
+		},
+		req:  in,
+		node: node,
+	}
+
+	err := node.scheduler.queue.Enqueue(dct)
+	if err != nil {
+		status := &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_UnexpectedError,
+			Reason:    err.Error(),
+		}
+		log.Error(err.Error())
+		return status, err
+	}
+	log.Debug("loadSegmentsTask Enqueue done", zap.Any("collectionID", in.CollectionID))
+
+	err = dct.WaitToFinish()
+	if err != nil {
+		status := &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_UnexpectedError,
+			Reason:    err.Error(),
+		}
+		log.Error(err.Error())
+		return status, err
+	}
+	log.Debug("loadSegmentsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
+
+	status := &commonpb.Status{
+		ErrorCode: commonpb.ErrorCode_Success,
+	}
+	return status, nil
+}
+
+func (node *QueryNode) ReleaseCollection(ctx context.Context, in *queryPb.ReleaseCollectionRequest) (*commonpb.Status, error) {
+	dct := &releaseCollectionTask{
+		baseTask: baseTask{
+			ctx:  ctx,
+			done: make(chan error),
+		},
+		req:  in,
+		node: node,
+	}
+
+	err := node.scheduler.queue.Enqueue(dct)
+	if err != nil {
+		status := &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_UnexpectedError,
+			Reason:    err.Error(),
+		}
+		log.Error(err.Error())
+		return status, err
+	}
+	log.Debug("releaseCollectionTask Enqueue done", zap.Any("collectionID", in.CollectionID))
+
+	err = dct.WaitToFinish()
+	if err != nil {
+		status := &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_UnexpectedError,
+			Reason:    err.Error(),
+		}
+		log.Error(err.Error())
+		return status, err
+	}
+	log.Debug("releaseCollectionTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
+
+	status := &commonpb.Status{
+		ErrorCode: commonpb.ErrorCode_Success,
+	}
+	return status, nil
+}
+
+func (node *QueryNode) ReleasePartitions(ctx context.Context, in *queryPb.ReleasePartitionsRequest) (*commonpb.Status, error) {
+	dct := &releasePartitionsTask{
+		baseTask: baseTask{
+			ctx:  ctx,
+			done: make(chan error),
+		},
+		req:  in,
+		node: node,
+	}
+
+	err := node.scheduler.queue.Enqueue(dct)
+	if err != nil {
+		status := &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_UnexpectedError,
+			Reason:    err.Error(),
+		}
+		log.Error(err.Error())
+		return status, err
+	}
+	log.Debug("releasePartitionsTask Enqueue done", zap.Any("collectionID", in.CollectionID))
+
+	err = dct.WaitToFinish()
+	if err != nil {
+		status := &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_UnexpectedError,
+			Reason:    err.Error(),
+		}
+		log.Error(err.Error())
+		return status, err
+	}
+	log.Debug("releasePartitionsTask WaitToFinish done", zap.Any("collectionID", in.CollectionID))
+
+	status := &commonpb.Status{
+		ErrorCode: commonpb.ErrorCode_Success,
+	}
+	return status, nil
+}
+
+// deprecated
+func (node *QueryNode) ReleaseSegments(ctx context.Context, in *queryPb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
+	status := &commonpb.Status{
+		ErrorCode: commonpb.ErrorCode_Success,
+	}
+	for _, id := range in.SegmentIDs {
+		err2 := node.loadService.segLoader.replica.removeSegment(id)
+		if err2 != nil {
+			// do not return; try to release all segments
+			status.ErrorCode = commonpb.ErrorCode_UnexpectedError
+			status.Reason = err2.Error()
+		}
+	}
+	return status, nil
+}
+
+func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *queryPb.GetSegmentInfoRequest) (*queryPb.GetSegmentInfoResponse, error) {
+	infos := make([]*queryPb.SegmentInfo, 0)
+	for _, id := range in.SegmentIDs {
+		segment, err := node.replica.getSegmentByID(id)
+		if err != nil {
+			continue
+		}
+		var indexName string
+		var indexID int64
+		// TODO:: segment has multi vec column
+		if len(segment.indexInfos) > 0 {
+			for fieldID := range segment.indexInfos {
+				indexName = segment.getIndexName(fieldID)
+				indexID = segment.getIndexID(fieldID)
+				break
+			}
+		}
+		info := &queryPb.SegmentInfo{
+			SegmentID:    segment.ID(),
+			CollectionID: segment.collectionID,
+			PartitionID:  segment.partitionID,
+			MemSize:      segment.getMemSize(),
+			NumRows:      segment.getRowCount(),
+			IndexName:    indexName,
+			IndexID:      indexID,
+		}
+		infos = append(infos, info)
+	}
+	return &queryPb.GetSegmentInfoResponse{
+		Status: &commonpb.Status{
+			ErrorCode: commonpb.ErrorCode_Success,
+		},
+		Infos: infos,
+	}, nil
+}
diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go
index 000983a812..af78959c1a 100644
--- a/internal/querynode/query_node.go
+++ b/internal/querynode/query_node.go
@@ -17,8 +17,6 @@ import (
 	"errors"
 	"fmt"
 	"math/rand"
-	"strconv"
-	"strings"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -29,10 +27,8 @@ import (
 	"github.com/zilliztech/milvus-distributed/internal/msgstream"
 	"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
 	"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
-	"github.com/zilliztech/milvus-distributed/internal/proto/milvuspb"
 	queryPb "github.com/zilliztech/milvus-distributed/internal/proto/querypb"
 	"github.com/zilliztech/milvus-distributed/internal/types"
-	"github.com/zilliztech/milvus-distributed/internal/util/typeutil"
 )
 
 type QueryNode struct {
@@ -59,6 +55,7 @@ type QueryNode struct {
 	dataService types.DataService
 
 	msFactory msgstream.Factory
+	scheduler *taskScheduler
 }
 
@@ -77,6 +74,7 @@ func NewQueryNode(ctx context.Context, queryNodeID UniqueID, factory msgstream.F
 		msFactory: factory,
 	}
 
+	node.scheduler = newTaskScheduler(ctx1)
 	node.replica = newCollectionReplica()
 	node.UpdateStateCode(internalpb.StateCode_Abnormal)
 	return node
@@ -96,6 +94,7 @@ func NewQueryNodeWithoutID(ctx context.Context, factory msgstream.Factory) *Quer
 		msFactory: factory,
 	}
 
+	node.scheduler = newTaskScheduler(ctx1)
 	node.replica = newCollectionReplica()
 	node.UpdateStateCode(internalpb.StateCode_Abnormal)
 
@@ -167,14 +166,14 @@ func (node *QueryNode) Start() error {
 	// init services and manager
 	node.searchService = newSearchService(node.queryNodeLoopCtx, node.replica, node.msFactory)
-	//node.metaService = newMetaService(node.queryNodeLoopCtx, node.replica)
-
 	node.loadService = newLoadService(node.queryNodeLoopCtx, node.masterService, node.dataService, node.indexService, node.replica)
 	node.statsService = newStatsService(node.queryNodeLoopCtx, node.replica, node.loadService.segLoader.indexLoader.fieldStatsChan, node.msFactory)
 
+	// start task scheduler
+	go node.scheduler.Start()
+
 	// start services
 	go node.searchService.start()
-	//go node.metaService.start()
 	go node.loadService.start()
 	go node.statsService.start()
 	node.UpdateStateCode(internalpb.StateCode_Healthy)
@@ -267,366 +266,3 @@ func (node *QueryNode) removeDataSyncService(collectionID UniqueID) {
 	defer node.dsServicesMu.Unlock()
 	delete(node.dataSyncServices, collectionID)
 }
-
-func (node *QueryNode) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
-	stats := &internalpb.ComponentStates{
-		Status: &commonpb.Status{
-			ErrorCode: commonpb.ErrorCode_Success,
-		},
-	}
-	code, ok := node.stateCode.Load().(internalpb.StateCode)
-	if !ok {
-		errMsg := "unexpected error in type assertion"
-		stats.Status = &commonpb.Status{
-			ErrorCode: commonpb.ErrorCode_UnexpectedError,
-			Reason:    errMsg,
-		}
-		return stats, errors.New(errMsg)
-	}
-	info := &internalpb.ComponentInfo{
-		NodeID:    Params.QueryNodeID,
-		Role:      typeutil.QueryNodeRole,
-		StateCode: code,
-	}
-	stats.State = info
-	return stats, nil
-}
-
-func (node *QueryNode) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
-	return &milvuspb.StringResponse{
-		Status: &commonpb.Status{
-			ErrorCode: commonpb.ErrorCode_Success,
-			Reason:    "",
-		},
-		Value: Params.QueryTimeTickChannelName,
-	}, nil
-}
-
-func (node *QueryNode) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
-	return &milvuspb.StringResponse{
-		Status: &commonpb.Status{
-			ErrorCode: commonpb.ErrorCode_Success,
-			Reason:    "",
-		},
-		Value: Params.StatsChannelName,
-	}, nil
-}
-
-func (node *QueryNode) AddQueryChannel(ctx context.Context, in *queryPb.AddQueryChannelRequest) (*commonpb.Status, error) {
-	if node.searchService == nil || node.searchService.searchMsgStream == nil {
-		errMsg := "null search service or null search message stream"
-		status := &commonpb.Status{
-			ErrorCode: commonpb.ErrorCode_UnexpectedError,
-			Reason:    errMsg,
-		}
-
-		return status, errors.New(errMsg)
-	}
-
-	// add request channel
-	consumeChannels := []string{in.RequestChannelID}
-	consumeSubName := Params.MsgChannelSubName
-	node.searchService.searchMsgStream.AsConsumer(consumeChannels, consumeSubName)
-	log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)
-
-	// add result channel
-	producerChannels := []string{in.ResultChannelID}
-	node.searchService.searchResultMsgStream.AsProducer(producerChannels)
-	log.Debug("querynode AsProducer: " + strings.Join(producerChannels, ", "))
-
-	status := &commonpb.Status{
-		ErrorCode: commonpb.ErrorCode_Success,
-	}
-	return status, nil
-}
-
-func (node *QueryNode) RemoveQueryChannel(ctx context.Context, in *queryPb.RemoveQueryChannelRequest) (*commonpb.Status, error) {
-	// if node.searchService == nil || node.searchService.searchMsgStream == nil {
-	// 	errMsg := "null search service or null search result message stream"
-	// 	status := &commonpb.Status{
-	// 		ErrorCode: commonpb.ErrorCode_UnexpectedError,
-	// 		Reason:    errMsg,
-	// 	}
-
-	// 	return status, errors.New(errMsg)
-	// }
-
-	// searchStream, ok := node.searchService.searchMsgStream.(*pulsarms.PulsarMsgStream)
-	// if !ok {
-	// 	errMsg := "type assertion failed for search message stream"
-	// 	status := &commonpb.Status{
-	// 		ErrorCode: commonpb.ErrorCode_UnexpectedError,
-	// 		Reason:    errMsg,
-	// 	}
-
-	// 	return status, errors.New(errMsg)
-	// }
-
-	// resultStream, ok := node.searchService.searchResultMsgStream.(*pulsarms.PulsarMsgStream)
-	// if !ok {
-	// 	errMsg := "type assertion failed for search result message stream"
-	// 	status := &commonpb.Status{
-	// 		ErrorCode: commonpb.ErrorCode_UnexpectedError,
-	// 		Reason:    errMsg,
-	// 	}
-
-	// 	return status, errors.New(errMsg)
-	// }
-
-	// // remove request channel
-	// consumeChannels := []string{in.RequestChannelID}
-	// consumeSubName := Params.MsgChannelSubName
-	// // TODO: searchStream.RemovePulsarConsumers(producerChannels)
-	// searchStream.AsConsumer(consumeChannels, consumeSubName)
-
-	// // remove result channel
-	// producerChannels := []string{in.ResultChannelID}
-	// // TODO: resultStream.RemovePulsarProducer(producerChannels)
-	// resultStream.AsProducer(producerChannels)
-
-	status := &commonpb.Status{
-		ErrorCode: commonpb.ErrorCode_Success,
-	}
-	return status, nil
-}
-
-func (node *QueryNode) WatchDmChannels(ctx context.Context, in *queryPb.WatchDmChannelsRequest) (*commonpb.Status, error) {
-	log.Debug("starting WatchDmChannels ...", zap.String("ChannelIDs", fmt.Sprintln(in.ChannelIDs)))
-	collectionID := in.CollectionID
-	ds, err := node.getDataSyncService(collectionID)
-	if err != nil || ds.dmStream == nil {
-		errMsg := "null data sync service or null data manipulation stream, collectionID = " + fmt.Sprintln(collectionID)
-		status := &commonpb.Status{
-			ErrorCode: commonpb.ErrorCode_UnexpectedError,
-			Reason:    errMsg,
-		}
-		log.Error(errMsg)
-		return status, errors.New(errMsg)
-	}
-
-	switch t := ds.dmStream.(type) {
-	case *msgstream.MqTtMsgStream:
-	default:
-		_ = t
-		errMsg := "type assertion failed for dm message stream"
-		status := &commonpb.Status{
-			ErrorCode: commonpb.ErrorCode_UnexpectedError,
-			Reason:    errMsg,
-		}
-		log.Error(errMsg)
-		return status, errors.New(errMsg)
-	}
-
-	getUniqueSubName := func() string {
-		prefixName := Params.MsgChannelSubName
-		return prefixName + "-" + strconv.FormatInt(collectionID, 10)
-	}
-
-	// add request channel
-	consumeChannels := in.ChannelIDs
-	toSeekInfo := make([]*internalpb.MsgPosition, 0)
-	toDirSubChannels := make([]string, 0)
-
-	consumeSubName := getUniqueSubName()
-
-	for _, info := range in.Infos {
-		if len(info.Pos.MsgID) == 0 {
-			toDirSubChannels = append(toDirSubChannels, info.ChannelID)
-			continue
-		}
-		info.Pos.MsgGroup = consumeSubName
-		toSeekInfo = append(toSeekInfo, info.Pos)
-
-		log.Debug("prevent inserting segments", zap.String("segmentIDs", fmt.Sprintln(info.ExcludedSegments)))
-		err := node.replica.addExcludedSegments(collectionID, info.ExcludedSegments)
-		if err != nil {
-			status := &commonpb.Status{
-				ErrorCode: commonpb.ErrorCode_UnexpectedError,
-				Reason:    err.Error(),
-			}
-			log.Error(err.Error())
-			return status, err
-		}
-	}
-
-	ds.dmStream.AsConsumer(toDirSubChannels, consumeSubName)
-	for _, pos := range toSeekInfo {
-		err := ds.dmStream.Seek(pos)
-		if err != nil {
-			errMsg := "msgStream seek error :" + err.Error()
-			status := &commonpb.Status{
-				ErrorCode: commonpb.ErrorCode_UnexpectedError,
-				Reason:    errMsg,
-			}
-			log.Error(errMsg)
-			return status, errors.New(errMsg)
-		}
-	}
-	log.Debug("querynode AsConsumer: " + strings.Join(consumeChannels, ", ") + " : " + consumeSubName)
-
-	status := &commonpb.Status{
-		ErrorCode: commonpb.ErrorCode_Success,
-	}
-	log.Debug("WatchDmChannels done", zap.String("ChannelIDs", fmt.Sprintln(in.ChannelIDs)))
-	return status, nil
-}
-
-func (node *QueryNode) LoadSegments(ctx context.Context, in *queryPb.LoadSegmentsRequest) (*commonpb.Status, error) {
-	// TODO: support db
-	collectionID := in.CollectionID
-	partitionID := in.PartitionID
-	segmentIDs := in.SegmentIDs
-	fieldIDs := in.FieldIDs
-	schema := in.Schema
-
-	log.Debug("query node load segment", zap.String("loadSegmentRequest", fmt.Sprintln(in)))
-
-	status := &commonpb.Status{
-		ErrorCode: commonpb.ErrorCode_Success,
-	}
-	hasCollection := node.replica.hasCollection(collectionID)
-	hasPartition := node.replica.hasPartition(partitionID)
-	if !hasCollection {
-		// loading init
-		err := node.replica.addCollection(collectionID, schema)
-		if err != nil {
-			status.ErrorCode = commonpb.ErrorCode_UnexpectedError
-			status.Reason = err.Error()
-			return status, err
-		}
-		node.replica.initExcludedSegments(collectionID)
-		newDS := newDataSyncService(node.queryNodeLoopCtx, node.replica, node.msFactory, collectionID)
-		// ignore duplicated dataSyncService error
-		node.addDataSyncService(collectionID, newDS)
-		ds, err := node.getDataSyncService(collectionID)
-		if err != nil {
-			status.ErrorCode = commonpb.ErrorCode_UnexpectedError
-			status.Reason = err.Error()
-			return status, err
-		}
-		go ds.start()
-		node.searchService.startSearchCollection(collectionID)
-	}
-	if !hasPartition {
-		err := node.replica.addPartition(collectionID, partitionID)
-		if err != nil {
-			status.ErrorCode = commonpb.ErrorCode_UnexpectedError
-			status.Reason = err.Error()
-			return status, err
-		}
-	}
-	err := node.replica.enablePartition(partitionID)
-	if err != nil {
-		status.ErrorCode = commonpb.ErrorCode_UnexpectedError
-		status.Reason = err.Error()
-		return status, err
-	}
-
-	if len(segmentIDs) == 0 {
-		return status, nil
-	}
-
-	err = node.loadService.loadSegmentPassively(collectionID, partitionID, segmentIDs, fieldIDs)
-	if err != nil {
-		status.ErrorCode = commonpb.ErrorCode_UnexpectedError
-		status.Reason = err.Error()
-		return status, err
-	}
-
-	log.Debug("LoadSegments done", zap.String("segmentIDs", fmt.Sprintln(in.SegmentIDs)))
-	return status, nil
-}
-
-func (node *QueryNode) ReleaseCollection(ctx context.Context, in *queryPb.ReleaseCollectionRequest) (*commonpb.Status, error) {
-	ds, err := node.getDataSyncService(in.CollectionID)
-	if err == nil && ds != nil {
-		ds.close()
-		node.removeDataSyncService(in.CollectionID)
-		node.replica.removeTSafe(in.CollectionID)
-		node.replica.removeExcludedSegments(in.CollectionID)
-	}
-
-	if node.searchService.hasSearchCollection(in.CollectionID) {
-		node.searchService.stopSearchCollection(in.CollectionID)
-	}
-
-	err = node.replica.removeCollection(in.CollectionID)
-	if err != nil {
-		status := &commonpb.Status{
-			ErrorCode: commonpb.ErrorCode_UnexpectedError,
-			Reason:    err.Error(),
-		}
-		return status, err
-	}
-
-	log.Debug("ReleaseCollection done", zap.Int64("collectionID", in.CollectionID))
-	return &commonpb.Status{
-		ErrorCode: commonpb.ErrorCode_Success,
-	}, nil
-}
-
-func (node *QueryNode) ReleasePartitions(ctx context.Context, in *queryPb.ReleasePartitionsRequest) (*commonpb.Status, error) {
-	status := &commonpb.Status{
-		ErrorCode: commonpb.ErrorCode_Success,
-	}
-	for _, id := range in.PartitionIDs {
-		err := node.loadService.segLoader.replica.removePartition(id)
-		if err != nil {
-			// not return, try to release all partitions
-			status.ErrorCode = commonpb.ErrorCode_UnexpectedError
-			status.Reason = err.Error()
-		}
-	}
-	return status, nil
-}
-
-func (node *QueryNode) ReleaseSegments(ctx context.Context, in *queryPb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
-	status := &commonpb.Status{
-		ErrorCode: commonpb.ErrorCode_Success,
-	}
-	for _, id := range in.SegmentIDs {
-		err2 := node.loadService.segLoader.replica.removeSegment(id)
-		if err2 != nil {
-			// not return, try to release all segments
-			status.ErrorCode = commonpb.ErrorCode_UnexpectedError
-			status.Reason = err2.Error()
-		}
-	}
-	return status, nil
-}
-
-func (node *QueryNode) GetSegmentInfo(ctx context.Context, in *queryPb.GetSegmentInfoRequest) (*queryPb.GetSegmentInfoResponse, error) {
-	infos := make([]*queryPb.SegmentInfo, 0)
-	for _, id := range in.SegmentIDs {
-		segment, err := node.replica.getSegmentByID(id)
-		if err != nil {
-			continue
-		}
-		var indexName string
-		var indexID int64
-		// TODO:: segment has multi vec column
-		if len(segment.indexInfos) > 0 {
-			for fieldID := range segment.indexInfos {
-				indexName = segment.getIndexName(fieldID)
-				indexID = segment.getIndexID(fieldID)
-				break
-			}
-		}
-		info := &queryPb.SegmentInfo{
-			SegmentID:    segment.ID(),
-			CollectionID: segment.collectionID,
-			PartitionID:  segment.partitionID,
-			MemSize:      segment.getMemSize(),
-			NumRows:      segment.getRowCount(),
-			IndexName:    indexName,
-			IndexID:      indexID,
-		}
-		infos = append(infos, info)
-	}
-	return &queryPb.GetSegmentInfoResponse{
-		Status: &commonpb.Status{
-			ErrorCode: commonpb.ErrorCode_Success,
-		},
-		Infos: infos,
-	}, nil
-}
diff --git a/internal/querynode/task.go b/internal/querynode/task.go
new file mode 100644
index 0000000000..989ef5dfb0
--- /dev/null
+++ b/internal/querynode/task.go
@@ -0,0 +1,222 @@
+package querynode
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"math/rand"
+
+	"go.uber.org/zap"
+
+	"github.com/zilliztech/milvus-distributed/internal/log"
+	queryPb "github.com/zilliztech/milvus-distributed/internal/proto/querypb"
+)
+
+type task interface {
+	ID() UniqueID       // return ReqID
+	SetID(uid UniqueID) // set ReqID
+	Timestamp() Timestamp
+	PreExecute(ctx context.Context) error
+	Execute(ctx context.Context) error
+	PostExecute(ctx context.Context) error
+	WaitToFinish() error
+	Notify(err error)
+	OnEnqueue() error
+}
+
+type baseTask struct {
+	done chan error
+	ctx  context.Context
+	id   UniqueID
+}
+
+type loadSegmentsTask struct {
+	baseTask
+	req  *queryPb.LoadSegmentsRequest
+	node *QueryNode
+}
+
+type releaseCollectionTask struct {
+	baseTask
+	req  *queryPb.ReleaseCollectionRequest
+	node *QueryNode
+}
+
+type releasePartitionsTask struct {
+	baseTask
+	req  *queryPb.ReleasePartitionsRequest
+	node *QueryNode
+}
+
+func (b *baseTask) ID() UniqueID {
+	return b.id
+}
+
+func (b *baseTask) SetID(uid UniqueID) {
+	b.id = uid
+}
+
+func (b *baseTask) WaitToFinish() error {
+	select {
+	case <-b.ctx.Done():
+		return errors.New("task timeout")
+	case err := <-b.done:
+		return err
+	}
+}
+
+func (b *baseTask) Notify(err error) {
+	b.done <- err
+}
+
+// loadSegmentsTask
+func (l *loadSegmentsTask) Timestamp() Timestamp {
+	return l.req.Base.Timestamp
+}
+
+func (l *loadSegmentsTask) OnEnqueue() error {
+	if l.req == nil || l.req.Base == nil {
+		l.SetID(rand.Int63n(100000000000))
+	} else {
+		l.SetID(l.req.Base.MsgID)
+	}
+	return nil
+}
+
+func (l *loadSegmentsTask) PreExecute(ctx context.Context) error {
+	return nil
+}
+
+func (l *loadSegmentsTask) Execute(ctx context.Context) error {
+	// TODO: support db
+	collectionID := l.req.CollectionID
+	partitionID := l.req.PartitionID
+	segmentIDs := l.req.SegmentIDs
+	fieldIDs := l.req.FieldIDs
+	schema := l.req.Schema
+
+	log.Debug("query node load segment", zap.String("loadSegmentRequest", fmt.Sprintln(l.req)))
+
+	hasCollection := l.node.replica.hasCollection(collectionID)
+	hasPartition := l.node.replica.hasPartition(partitionID)
+	if !hasCollection {
+		// loading init
+		err := l.node.replica.addCollection(collectionID, schema)
+		if err != nil {
+			return err
+		}
+		l.node.replica.initExcludedSegments(collectionID)
+		newDS := newDataSyncService(l.node.queryNodeLoopCtx, l.node.replica, l.node.msFactory, collectionID)
+		// ignore duplicated dataSyncService error
+		_ = l.node.addDataSyncService(collectionID, newDS)
+		ds, err := l.node.getDataSyncService(collectionID)
+		if err != nil {
+			return err
+		}
+		go ds.start()
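+		// the data sync service is now running; register the collection with the
+		// search service as well so incoming queries can target it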
+		l.node.searchService.startSearchCollection(collectionID)
+	}
+	if !hasPartition {
+		err := l.node.replica.addPartition(collectionID, partitionID)
+		if err != nil {
+			return err
+		}
+	}
+	err := l.node.replica.enablePartition(partitionID)
+	if err != nil {
+		return err
+	}
+
+	if len(segmentIDs) == 0 {
+		return nil
+	}
+
+	err = l.node.loadService.loadSegmentPassively(collectionID, partitionID, segmentIDs, fieldIDs)
+	if err != nil {
+		return err
+	}
+
+	log.Debug("LoadSegments done", zap.String("segmentIDs", fmt.Sprintln(l.req.SegmentIDs)))
+	return nil
+}
+
+func (l *loadSegmentsTask) PostExecute(ctx context.Context) error {
+	return nil
+}
+
+// releaseCollectionTask
+func (r *releaseCollectionTask) Timestamp() Timestamp {
+	return r.req.Base.Timestamp
+}
+
+func (r *releaseCollectionTask) OnEnqueue() error {
+	if r.req == nil || r.req.Base == nil {
+		r.SetID(rand.Int63n(100000000000))
+	} else {
+		r.SetID(r.req.Base.MsgID)
+	}
+	return nil
+}
+
+func (r *releaseCollectionTask) PreExecute(ctx context.Context) error {
+	return nil
+}
+
+func (r *releaseCollectionTask) Execute(ctx context.Context) error {
+	ds, err := r.node.getDataSyncService(r.req.CollectionID)
+	if err == nil && ds != nil {
+		ds.close()
+		r.node.removeDataSyncService(r.req.CollectionID)
+		r.node.replica.removeTSafe(r.req.CollectionID)
+		r.node.replica.removeExcludedSegments(r.req.CollectionID)
+	}
+
+	if r.node.searchService.hasSearchCollection(r.req.CollectionID) {
+		r.node.searchService.stopSearchCollection(r.req.CollectionID)
+	}
+
+	err = r.node.replica.removeCollection(r.req.CollectionID)
+	if err != nil {
+		return err
+	}
+
+	log.Debug("ReleaseCollection done", zap.Int64("collectionID", r.req.CollectionID))
+	return nil
+}
+
+func (r *releaseCollectionTask) PostExecute(ctx context.Context) error {
+	return nil
+}
+
+// releasePartitionsTask
+func (r *releasePartitionsTask) Timestamp() Timestamp {
+	return r.req.Base.Timestamp
+}
+
+func (r *releasePartitionsTask) OnEnqueue() error {
+	if r.req == nil || r.req.Base == nil {
+		r.SetID(rand.Int63n(100000000000))
+	} else {
+		r.SetID(r.req.Base.MsgID)
+	}
+	return nil
+}
+
+func (r *releasePartitionsTask) PreExecute(ctx context.Context) error {
+	return nil
+}
+
+func (r *releasePartitionsTask) Execute(ctx context.Context) error {
+	for _, id := range r.req.PartitionIDs {
+		err := r.node.loadService.segLoader.replica.removePartition(id)
+		if err != nil {
+			// do not return; try to release all partitions
+			log.Error(err.Error())
+		}
+	}
+	return nil
+}
+
+func (r *releasePartitionsTask) PostExecute(ctx context.Context) error {
+	return nil
+}
diff --git a/internal/querynode/task_queue.go b/internal/querynode/task_queue.go
new file mode 100644
index 0000000000..e90ee2f2fc
--- /dev/null
+++ b/internal/querynode/task_queue.go
@@ -0,0 +1,153 @@
+package querynode
+
+import (
+	"container/list"
+	"errors"
+	"sync"
+
+	"go.uber.org/zap"
+
+	"github.com/zilliztech/milvus-distributed/internal/log"
+)
+
+const maxTaskNum = 1024
+
+type taskQueue interface {
+	utChan() <-chan int
+	utEmpty() bool
+	utFull() bool
+	addUnissuedTask(t task) error
+	PopUnissuedTask() task
+	AddActiveTask(t task)
+	PopActiveTask(tID UniqueID) task
+	Enqueue(t task) error
+}
+
+type baseTaskQueue struct {
+	utMu          sync.Mutex // guards unissuedTasks
+	unissuedTasks *list.List
+
+	atMu        sync.Mutex // guards activeTasks
+	activeTasks map[UniqueID]task
+
+	maxTaskNum int64    // maxTaskNum must not change
+	utBufChan  chan int // to block scheduler
+
+	scheduler *taskScheduler
+}
+
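+// loadAndReleaseTaskQueue is the single queue used by the query node's task
+// scheduler; its Enqueue takes an extra mutex so load and release requests are
+// admitted one at a time, in arrival order.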
+type loadAndReleaseTaskQueue struct {
+	baseTaskQueue
+	mu sync.Mutex
+}
+
+// baseTaskQueue
+func (queue *baseTaskQueue) utChan() <-chan int {
+	return queue.utBufChan
+}
+
+func (queue *baseTaskQueue) utEmpty() bool {
+	return queue.unissuedTasks.Len() == 0
+}
+
+func (queue *baseTaskQueue) utFull() bool {
+	return int64(queue.unissuedTasks.Len()) >= queue.maxTaskNum
+}
+
+func (queue *baseTaskQueue) addUnissuedTask(t task) error {
+	queue.utMu.Lock()
+	defer queue.utMu.Unlock()
+
+	if queue.utFull() {
+		return errors.New("task queue is full")
+	}
+
+	if queue.unissuedTasks.Len() <= 0 {
+		queue.unissuedTasks.PushBack(t)
+		queue.utBufChan <- 1
+		return nil
+	}
+
+	if t.Timestamp() >= queue.unissuedTasks.Back().Value.(task).Timestamp() {
+		queue.unissuedTasks.PushBack(t)
+		queue.utBufChan <- 1
+		return nil
+	}
+
+	for e := queue.unissuedTasks.Front(); e != nil; e = e.Next() {
+		if t.Timestamp() <= e.Value.(task).Timestamp() {
+			queue.unissuedTasks.InsertBefore(t, e)
+			queue.utBufChan <- 1
+			return nil
+		}
+	}
+	return errors.New("unexpected error in addUnissuedTask")
+}
+
+func (queue *baseTaskQueue) PopUnissuedTask() task {
+	queue.utMu.Lock()
+	defer queue.utMu.Unlock()
+
+	if queue.unissuedTasks.Len() <= 0 {
+		log.Fatal("unissued task list is empty!")
+		return nil
+	}
+
+	ft := queue.unissuedTasks.Front()
+	queue.unissuedTasks.Remove(ft)
+
+	return ft.Value.(task)
+}
+
+func (queue *baseTaskQueue) AddActiveTask(t task) {
+	queue.atMu.Lock()
+	defer queue.atMu.Unlock()
+
+	tID := t.ID()
+	_, ok := queue.activeTasks[tID]
+	if ok {
+		log.Warn("task with ID already in active task list", zap.Int64("taskID", tID))
+	}
+
+	queue.activeTasks[tID] = t
+}
+
+func (queue *baseTaskQueue) PopActiveTask(tID UniqueID) task {
+	queue.atMu.Lock()
+	defer queue.atMu.Unlock()
+
+	t, ok := queue.activeTasks[tID]
+	if ok {
+		delete(queue.activeTasks, tID)
+		return t
+	}
+	log.Debug("cannot find ID in the active task list", zap.Int64("taskID", tID))
+	return nil
+}
+
+func (queue *baseTaskQueue) Enqueue(t task) error {
+	err := t.OnEnqueue()
+	if err != nil {
+		return err
+	}
+	return queue.addUnissuedTask(t)
+}
+
+// loadAndReleaseTaskQueue
+func (queue *loadAndReleaseTaskQueue) Enqueue(t task) error {
+	queue.mu.Lock()
+	defer queue.mu.Unlock()
+	return queue.baseTaskQueue.Enqueue(t)
+}
+
+func newLoadAndReleaseTaskQueue(scheduler *taskScheduler) *loadAndReleaseTaskQueue {
+	return &loadAndReleaseTaskQueue{
+		baseTaskQueue: baseTaskQueue{
+			unissuedTasks: list.New(),
+			activeTasks:   make(map[UniqueID]task),
+			maxTaskNum:    maxTaskNum,
+			utBufChan:     make(chan int, maxTaskNum),
+			scheduler:     scheduler,
+		},
+	}
+}
diff --git a/internal/querynode/task_scheduler.go b/internal/querynode/task_scheduler.go
new file mode 100644
index 0000000000..e05b613c3b
--- /dev/null
+++ b/internal/querynode/task_scheduler.go
@@ -0,0 +1,76 @@
+package querynode
+
+import (
+	"context"
+	"sync"
+
+	"github.com/zilliztech/milvus-distributed/internal/log"
+)
+
+type taskScheduler struct {
+	ctx    context.Context
+	cancel context.CancelFunc
+
+	wg    sync.WaitGroup
+	queue taskQueue
+}
+
+func newTaskScheduler(ctx context.Context) *taskScheduler {
+	ctx1, cancel := context.WithCancel(ctx)
+	s := &taskScheduler{
+		ctx:    ctx1,
+		cancel: cancel,
+	}
+	s.queue = newLoadAndReleaseTaskQueue(s)
+	return s
+}
+
+func (s *taskScheduler) processTask(t task, q taskQueue) {
+	// TODO: ctx?
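+	// Lifecycle: PreExecute -> Execute -> PostExecute. Notify is deferred so the
+	// caller blocked in WaitToFinish is always woken with the final error,
+	// whichever stage fails.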
+	err := t.PreExecute(s.ctx)
+
+	defer func() {
+		t.Notify(err)
+	}()
+	if err != nil {
+		log.Error(err.Error())
+		return
+	}
+
+	q.AddActiveTask(t)
+	defer func() {
+		q.PopActiveTask(t.ID())
+	}()
+
+	err = t.Execute(s.ctx)
+	if err != nil {
+		log.Error(err.Error())
+		return
+	}
+	err = t.PostExecute(s.ctx)
+}
+
+func (s *taskScheduler) loadAndReleaseLoop() {
+	defer s.wg.Done()
+	for {
+		select {
+		case <-s.ctx.Done():
+			return
+		case <-s.queue.utChan():
+			if !s.queue.utEmpty() {
+				t := s.queue.PopUnissuedTask()
+				go s.processTask(t, s.queue)
+			}
+		}
+	}
+}
+
+func (s *taskScheduler) Start() {
+	s.wg.Add(1)
+	go s.loadAndReleaseLoop()
+}
+
+func (s *taskScheduler) Close() {
+	s.cancel()
+	s.wg.Wait()
+}
diff --git a/internal/queryservice/queryservice.go b/internal/queryservice/queryservice.go
index 0441a09cfa..13ed5ab1f8 100644
--- a/internal/queryservice/queryservice.go
+++ b/internal/queryservice/queryservice.go
@@ -498,6 +498,10 @@ func (qs *QueryService) LoadPartitions(ctx context.Context, req *querypb.LoadPar
 	segment2Node := qs.shuffleSegmentsToQueryNode(toLoadSegmentIDs)
 	for nodeID, assignedSegmentIDs := range segment2Node {
 		loadSegmentRequest := &querypb.LoadSegmentsRequest{
+			// TODO: use unique id allocator to assign reqID
+			Base: &commonpb.MsgBase{
+				MsgID: rand.Int63n(10000000000),
+			},
 			CollectionID: collectionID,
 			PartitionID:  partitionID,
 			SegmentIDs:   assignedSegmentIDs,