From 4136009a9ab2a62c2f72392aec6fec72763e5fbe Mon Sep 17 00:00:00 2001
From: Xiaofan <83447078+xiaofan-luan@users.noreply.github.com>
Date: Mon, 31 Oct 2022 13:55:33 +0800
Subject: [PATCH] Support load segments/channels in parallel (#20036)

Signed-off-by: xiaofan-luan
Signed-off-by: xiaofan-luan
---
 internal/querynode/collection.go             |  71 +++
 internal/querynode/data_sync_service.go      |  26 +-
 internal/querynode/data_sync_service_test.go |  51 --
 internal/querynode/impl.go                   | 150 ++++--
 internal/querynode/log_segment_task.go       | 226 ++++++++
 internal/querynode/meta_replica.go           |   3 +-
 internal/querynode/mock_test.go              |  13 +-
 internal/querynode/query_node.go             |  14 +
 internal/querynode/query_shard_service.go    |   1 +
 internal/querynode/segment_loader.go         |  26 +-
 internal/querynode/task.go                   | 509 ------------------
 internal/querynode/watch_dm_channels_task.go | 354 +++++++++++++
 internal/util/lock/key_lock.go               | 131 +++++
 internal/util/lock/key_lock_test.go          |  69 +++
 14 files changed, 1010 insertions(+), 634 deletions(-)
 create mode 100644 internal/querynode/log_segment_task.go
 create mode 100644 internal/querynode/watch_dm_channels_task.go
 create mode 100644 internal/util/lock/key_lock.go
 create mode 100644 internal/util/lock/key_lock_test.go

diff --git a/internal/querynode/collection.go b/internal/querynode/collection.go
index 130b727ee6..014bcec583 100644
--- a/internal/querynode/collection.go
+++ b/internal/querynode/collection.go
@@ -50,6 +50,7 @@ type Collection struct {
 	partitionIDs []UniqueID
 	schema       *schemapb.CollectionSchema
 
+	// TODO, remove delta channels
 	channelMu sync.RWMutex
 	vChannels []Channel
 	pChannels []Channel
@@ -225,6 +226,41 @@ func (c *Collection) getVDeltaChannels() []Channel {
 	return tmpChannels
 }
 
+func (c *Collection) AddChannels(toLoadChannels []Channel, VPChannels map[string]string) []Channel {
+	c.channelMu.Lock()
+	defer c.channelMu.Unlock()
+
+	retVChannels := []Channel{}
+	for _, toLoadChannel := range toLoadChannels {
+		if !c.isVChannelExist(toLoadChannel) {
+			retVChannels = append(retVChannels, toLoadChannel)
+			c.vChannels = append(c.vChannels, toLoadChannel)
+			if !c.isPChannelExist(VPChannels[toLoadChannel]) {
+				c.pChannels = append(c.pChannels, VPChannels[toLoadChannel])
+			}
+		}
+	}
+	return retVChannels
+}
+
+func (c *Collection) isVChannelExist(channel string) bool {
+	for _, vChannel := range c.vChannels {
+		if vChannel == channel {
+			return true
+		}
+	}
+	return false
+}
+
+func (c *Collection) isPChannelExist(channel string) bool {
+	for _, vChannel := range c.pChannels {
+		if vChannel == channel {
+			return true
+		}
+	}
+	return false
+}
+
 // addVChannels add virtual channels to collection
 func (c *Collection) addVDeltaChannels(channels []Channel) {
 	c.channelMu.Lock()
@@ -268,6 +304,41 @@ func (c *Collection) removeVDeltaChannel(channel Channel) {
 	metrics.QueryNodeNumDeltaChannels.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Sub(float64(len(c.vDeltaChannels)))
 }
 
+func (c *Collection) AddVDeltaChannels(toLoadChannels []Channel, VPChannels map[string]string) []Channel {
+	c.channelMu.Lock()
+	defer c.channelMu.Unlock()
+
+	retVDeltaChannels := []Channel{}
+	for _, toLoadChannel := range toLoadChannels {
+		if !c.isVDeltaChannelExist(toLoadChannel) {
+			retVDeltaChannels = append(retVDeltaChannels, toLoadChannel)
+			c.vDeltaChannels = append(c.vDeltaChannels, toLoadChannel)
+			if !c.isPDeltaChannelExist(VPChannels[toLoadChannel]) {
+				c.pDeltaChannels = append(c.pDeltaChannels, VPChannels[toLoadChannel])
+			}
+		}
+	}
+	return retVDeltaChannels
+}
+
+func (c
*Collection) isVDeltaChannelExist(channel string) bool { + for _, vDeltaChanel := range c.vDeltaChannels { + if vDeltaChanel == channel { + return true + } + } + return false +} + +func (c *Collection) isPDeltaChannelExist(channel string) bool { + for _, vChannel := range c.pDeltaChannels { + if vChannel == channel { + return true + } + } + return false +} + // setReleaseTime records when collection is released func (c *Collection) setReleaseTime(t Timestamp, released bool) { c.releaseMu.Lock() diff --git a/internal/querynode/data_sync_service.go b/internal/querynode/data_sync_service.go index 221b4ef748..0159f2de36 100644 --- a/internal/querynode/data_sync_service.go +++ b/internal/querynode/data_sync_service.go @@ -49,32 +49,13 @@ func (dsService *dataSyncService) getFlowGraphNum() int { return len(dsService.dmlChannel2FlowGraph) + len(dsService.deltaChannel2FlowGraph) } -// checkReplica used to check replica info before init flow graph, it's a private method of dataSyncService -func (dsService *dataSyncService) checkReplica(collectionID UniqueID) error { - // check if the collection exists - coll, err := dsService.metaReplica.getCollectionByID(collectionID) - if err != nil { - return err - } - for _, channel := range coll.getVChannels() { - if _, err := dsService.tSafeReplica.getTSafe(channel); err != nil { - return fmt.Errorf("getTSafe failed, err = %s", err) - } - } - for _, channel := range coll.getVDeltaChannels() { - if _, err := dsService.tSafeReplica.getTSafe(channel); err != nil { - return fmt.Errorf("getTSafe failed, err = %s", err) - } - } - return nil -} - // addFlowGraphsForDMLChannels add flowGraphs to dmlChannel2FlowGraph func (dsService *dataSyncService) addFlowGraphsForDMLChannels(collectionID UniqueID, dmlChannels []string) (map[string]*queryNodeFlowGraph, error) { dsService.mu.Lock() defer dsService.mu.Unlock() - if err := dsService.checkReplica(collectionID); err != nil { + _, err := dsService.metaReplica.getCollectionByID(collectionID) + if err != nil { return nil, err } @@ -118,7 +99,8 @@ func (dsService *dataSyncService) addFlowGraphsForDeltaChannels(collectionID Uni dsService.mu.Lock() defer dsService.mu.Unlock() - if err := dsService.checkReplica(collectionID); err != nil { + _, err := dsService.metaReplica.getCollectionByID(collectionID) + if err != nil { return nil, err } diff --git a/internal/querynode/data_sync_service_test.go b/internal/querynode/data_sync_service_test.go index 5962f781c6..790e606545 100644 --- a/internal/querynode/data_sync_service_test.go +++ b/internal/querynode/data_sync_service_test.go @@ -153,57 +153,6 @@ func TestDataSyncService_DeltaFlowGraphs(t *testing.T) { }) } -func TestDataSyncService_checkReplica(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - replica, err := genSimpleReplica() - assert.NoError(t, err) - - fac := genFactory() - assert.NoError(t, err) - - tSafe := newTSafeReplica() - dataSyncService := newDataSyncService(ctx, replica, tSafe, fac) - assert.NotNil(t, dataSyncService) - defer dataSyncService.close() - - t.Run("test checkReplica", func(t *testing.T) { - err = dataSyncService.checkReplica(defaultCollectionID) - assert.NoError(t, err) - }) - - t.Run("test collection doesn't exist", func(t *testing.T) { - err = dataSyncService.metaReplica.removeCollection(defaultCollectionID) - assert.NoError(t, err) - err = dataSyncService.checkReplica(defaultCollectionID) - assert.Error(t, err) - coll := dataSyncService.metaReplica.addCollection(defaultCollectionID, 
genTestCollectionSchema())
-		assert.NotNil(t, coll)
-	})
-
-	t.Run("test cannot find tSafe", func(t *testing.T) {
-		coll, err := dataSyncService.metaReplica.getCollectionByID(defaultCollectionID)
-		assert.NoError(t, err)
-		coll.addVDeltaChannels([]Channel{defaultDeltaChannel})
-		coll.addVChannels([]Channel{defaultDMLChannel})
-
-		dataSyncService.tSafeReplica.addTSafe(defaultDeltaChannel)
-		dataSyncService.tSafeReplica.addTSafe(defaultDMLChannel)
-
-		dataSyncService.tSafeReplica.removeTSafe(defaultDeltaChannel)
-		err = dataSyncService.checkReplica(defaultCollectionID)
-		assert.Error(t, err)
-
-		dataSyncService.tSafeReplica.removeTSafe(defaultDMLChannel)
-		err = dataSyncService.checkReplica(defaultCollectionID)
-		assert.Error(t, err)
-
-		dataSyncService.tSafeReplica.addTSafe(defaultDeltaChannel)
-		dataSyncService.tSafeReplica.addTSafe(defaultDMLChannel)
-	})
-}
-
 type DataSyncServiceSuite struct {
 	suite.Suite
 	factory dependency.Factory
diff --git a/internal/querynode/impl.go b/internal/querynode/impl.go
index f7d837a93d..031a3b19d3 100644
--- a/internal/querynode/impl.go
+++ b/internal/querynode/impl.go
@@ -20,8 +20,10 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"sort"
 	"strconv"
 	"sync"
+	"time"
 
 	"github.com/golang/protobuf/proto"
 	"go.uber.org/zap"
@@ -308,37 +310,57 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, in *querypb.WatchDmC
 		node: node,
 	}
 
-	err := node.scheduler.queue.Enqueue(task)
-	if err != nil {
-		status := &commonpb.Status{
-			ErrorCode: commonpb.ErrorCode_UnexpectedError,
-			Reason:    err.Error(),
-		}
-		log.Warn(err.Error())
-		return status, nil
-	}
-	log.Info("watchDmChannelsTask Enqueue done", zap.Int64("collectionID", in.CollectionID), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()), zap.Int64("replicaID", in.GetReplicaID()))
-	waitFunc := func() (*commonpb.Status, error) {
-		err = task.WaitToFinish()
+	startTs := time.Now()
+	log.Info("watchDmChannels init", zap.Int64("collectionID", in.CollectionID),
+		zap.String("channelName", in.Infos[0].GetChannelName()),
+		zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
+	// currently we only support loading one channel at a time
+	node.taskLock.RLock(strconv.FormatInt(in.Infos[0].CollectionID, 10))
+	defer node.taskLock.RUnlock(strconv.FormatInt(in.Infos[0].CollectionID, 10))
+	future := node.taskPool.Submit(func() (interface{}, error) {
+		log.Info("watchDmChannels start", zap.Int64("collectionID", in.CollectionID),
+			zap.String("channelName", in.Infos[0].GetChannelName()),
+			zap.Duration("timeInQueue", time.Since(startTs)))
+		err := task.PreExecute(ctx)
 		if err != nil {
 			status := &commonpb.Status{
 				ErrorCode: commonpb.ErrorCode_UnexpectedError,
 				Reason:    err.Error(),
 			}
-			log.Warn(err.Error())
+			log.Warn("failed to subscribe channel on preExecute", zap.Error(err))
+			return status, nil
+		}
+
+		err = task.Execute(ctx)
+		if err != nil {
+			status := &commonpb.Status{
+				ErrorCode: commonpb.ErrorCode_UnexpectedError,
+				Reason:    err.Error(),
+			}
+			log.Warn("failed to subscribe channel", zap.Error(err))
+			return status, nil
+		}
+
+		err = task.PostExecute(ctx)
+		if err != nil {
+			status := &commonpb.Status{
+				ErrorCode: commonpb.ErrorCode_UnexpectedError,
+				Reason:    err.Error(),
+			}
+			log.Warn("failed to subscribe channel on postExecute", zap.Error(err))
 			return status, nil
 		}
 		sc, _ := node.ShardClusterService.getShardCluster(in.Infos[0].GetChannelName())
 		sc.SetupFirstVersion()
-
-		log.Info("watchDmChannelsTask WaitToFinish done", zap.Int64("collectionID", in.CollectionID), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
+		log.Info("watchDmChannelsTask succeeded", zap.Int64("collectionID", in.CollectionID),
+			zap.String("channelName", in.Infos[0].GetChannelName()), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
 		return &commonpb.Status{
 			ErrorCode: commonpb.ErrorCode_Success,
 		}, nil
-	}
-
-	return waitFunc()
+	})
+	ret, _ := future.Await()
+	return ret.(*commonpb.Status), nil
 }
 
 func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) {
@@ -375,13 +397,15 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
 		node: node,
 	}
 
+	node.taskLock.Lock(strconv.FormatInt(dct.req.CollectionID, 10))
+	defer node.taskLock.Unlock(strconv.FormatInt(dct.req.CollectionID, 10))
 	err := node.scheduler.queue.Enqueue(dct)
 	if err != nil {
 		status := &commonpb.Status{
 			ErrorCode: commonpb.ErrorCode_UnexpectedError,
 			Reason:    err.Error(),
 		}
-		log.Warn(err.Error())
+		log.Warn("failed to enqueue unsubscribe channel task", zap.Error(err))
 		return status, nil
 	}
 	log.Info("unsubDmChannel(ReleaseCollection) enqueue done", zap.Int64("collectionID", req.GetCollectionID()))
@@ -389,7 +413,7 @@ func (node *QueryNode) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmC
 	func() {
 		err = dct.WaitToFinish()
 		if err != nil {
-			log.Warn(err.Error())
+			log.Warn("failed to finish unsubscribe channel task", zap.Error(err))
 			return
 		}
 		log.Info("unsubDmChannel(ReleaseCollection) WaitToFinish done", zap.Int64("collectionID", req.GetCollectionID()))
@@ -439,35 +463,66 @@ func (node *QueryNode) LoadSegments(ctx context.Context, in *querypb.LoadSegment
 	for _, info := range in.Infos {
 		segmentIDs = append(segmentIDs, info.SegmentID)
 	}
-	err := node.scheduler.queue.Enqueue(task)
-	if err != nil {
-		status := &commonpb.Status{
-			ErrorCode: commonpb.ErrorCode_UnexpectedError,
-			Reason:    err.Error(),
-		}
-		log.Warn(err.Error())
-		return status, nil
+	sort.SliceStable(segmentIDs, func(i, j int) bool {
+		return segmentIDs[i] < segmentIDs[j]
+	})
+
+	startTs := time.Now()
+	log.Info("loadSegmentsTask init", zap.Int64("collectionID", in.CollectionID),
+		zap.Int64s("segmentIDs", segmentIDs),
+		zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
+
+	node.taskLock.RLock(strconv.FormatInt(in.CollectionID, 10))
+	for _, segmentID := range segmentIDs {
+		node.taskLock.Lock(strconv.FormatInt(segmentID, 10))
 	}
-	log.Info("loadSegmentsTask Enqueue done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", segmentIDs), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
-
-	waitFunc := func() (*commonpb.Status, error) {
-		err = task.WaitToFinish()
+	// release all task locks
+	defer func() {
+		node.taskLock.RUnlock(strconv.FormatInt(in.CollectionID, 10))
+		for _, id := range segmentIDs {
+			node.taskLock.Unlock(strconv.FormatInt(id, 10))
+		}
+	}()
+	future := node.taskPool.Submit(func() (interface{}, error) {
+		log.Info("loadSegmentsTask start", zap.Int64("collectionID", in.CollectionID),
+			zap.Int64s("segmentIDs", segmentIDs),
+			zap.Duration("timeInQueue", time.Since(startTs)))
+		err := task.PreExecute(ctx)
 		if err != nil {
 			status := &commonpb.Status{
 				ErrorCode: commonpb.ErrorCode_UnexpectedError,
 				Reason:    err.Error(),
 			}
-			log.Warn(err.Error())
+			log.Warn("failed to load segments on preExecute", zap.Error(err))
 			return status, nil
 		}
-		log.Info("loadSegmentsTask WaitToFinish done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", segmentIDs), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID()))
 
+		err = task.Execute(ctx)
+		if err != nil {
+
status := &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + } + log.Warn("failed to load segment", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", segmentIDs), zap.Error(err)) + return status, nil + } + + err = task.PostExecute(ctx) + if err != nil { + status := &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + } + log.Warn("failed to load segments on postExecute ", zap.Error(err)) + return status, nil + } + log.Info("loadSegmentsTask done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", segmentIDs), zap.Int64("nodeID", Params.QueryNodeCfg.GetNodeID())) return &commonpb.Status{ ErrorCode: commonpb.ErrorCode_Success, }, nil - } - - return waitFunc() + }) + ret, _ := future.Await() + return ret.(*commonpb.Status), nil } // ReleaseCollection clears all data related to this collection on the querynode @@ -490,6 +545,8 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.Releas node: node, } + node.taskLock.Lock(strconv.FormatInt(dct.req.CollectionID, 10)) + defer node.taskLock.Unlock(strconv.FormatInt(dct.req.CollectionID, 10)) err := node.scheduler.queue.Enqueue(dct) if err != nil { status := &commonpb.Status{ @@ -536,6 +593,8 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.Releas node: node, } + node.taskLock.Lock(strconv.FormatInt(dct.req.CollectionID, 10)) + defer node.taskLock.Unlock(strconv.FormatInt(dct.req.CollectionID, 10)) err := node.scheduler.queue.Enqueue(dct) if err != nil { status := &commonpb.Status{ @@ -587,6 +646,23 @@ func (node *QueryNode) ReleaseSegments(ctx context.Context, in *querypb.ReleaseS return node.TransferRelease(ctx, in) } + log.Info("start to release segments", zap.Int64("collectionID", in.CollectionID), zap.Int64s("segmentIDs", in.SegmentIDs)) + node.taskLock.RLock(strconv.FormatInt(in.CollectionID, 10)) + sort.SliceStable(in.SegmentIDs, func(i, j int) bool { + return in.SegmentIDs[i] < in.SegmentIDs[j] + }) + + for _, segmentID := range in.SegmentIDs { + node.taskLock.Lock(strconv.FormatInt(segmentID, 10)) + } + + // release all task locks + defer func() { + node.taskLock.RUnlock(strconv.FormatInt(in.CollectionID, 10)) + for _, id := range in.SegmentIDs { + node.taskLock.Unlock(strconv.FormatInt(id, 10)) + } + }() for _, id := range in.SegmentIDs { switch in.GetScope() { case querypb.DataScope_Streaming: diff --git a/internal/querynode/log_segment_task.go b/internal/querynode/log_segment_task.go new file mode 100644 index 0000000000..fd24762c2c --- /dev/null +++ b/internal/querynode/log_segment_task.go @@ -0,0 +1,226 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package querynode + +import ( + "context" + "fmt" + + "go.uber.org/zap" + "golang.org/x/sync/errgroup" + + "github.com/milvus-io/milvus/internal/log" + queryPb "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/samber/lo" +) + +type loadSegmentsTask struct { + baseTask + req *queryPb.LoadSegmentsRequest + node *QueryNode +} + +// loadSegmentsTask +func (l *loadSegmentsTask) PreExecute(ctx context.Context) error { + log.Info("LoadSegmentTask PreExecute start", zap.Int64("msgID", l.req.Base.MsgID)) + var err error + // init meta + collectionID := l.req.GetCollectionID() + l.node.metaReplica.addCollection(collectionID, l.req.GetSchema()) + for _, partitionID := range l.req.GetLoadMeta().GetPartitionIDs() { + err = l.node.metaReplica.addPartition(collectionID, partitionID) + if err != nil { + return err + } + } + + // filter segments that are already loaded in this querynode + var filteredInfos []*queryPb.SegmentLoadInfo + for _, info := range l.req.Infos { + has, err := l.node.metaReplica.hasSegment(info.SegmentID, segmentTypeSealed) + if err != nil { + return err + } + if !has { + filteredInfos = append(filteredInfos, info) + } else { + log.Info("ignore segment that is already loaded", zap.Int64("collectionID", info.CollectionID), zap.Int64("segmentID", info.SegmentID)) + } + } + l.req.Infos = filteredInfos + log.Info("LoadSegmentTask PreExecute done", zap.Int64("msgID", l.req.Base.MsgID)) + return nil +} + +func (l *loadSegmentsTask) Execute(ctx context.Context) error { + log.Info("LoadSegmentTask Execute start", zap.Int64("msgID", l.req.Base.MsgID)) + + if len(l.req.Infos) == 0 { + log.Info("all segments loaded", zap.Int64("msgID", l.req.GetBase().GetMsgID())) + return nil + } + + segmentIDs := lo.Map(l.req.Infos, func(info *queryPb.SegmentLoadInfo, idx int) UniqueID { return info.SegmentID }) + l.node.metaReplica.addSegmentsLoadingList(segmentIDs) + defer l.node.metaReplica.removeSegmentsLoadingList(segmentIDs) + err := l.node.loader.LoadSegment(l.ctx, l.req, segmentTypeSealed) + if err != nil { + log.Warn("failed to load segment", zap.Int64("collectionID", l.req.CollectionID), + zap.Int64("replicaID", l.req.ReplicaID), zap.Error(err)) + return err + } + vchanName := make([]string, 0) + for _, deltaPosition := range l.req.DeltaPositions { + vchanName = append(vchanName, deltaPosition.ChannelName) + } + + // TODO delta channel need to released 1. if other watchDeltaChannel fail 2. 
when segment release + err = l.watchDeltaChannel(vchanName) + if err != nil { + // roll back + for _, segment := range l.req.Infos { + l.node.metaReplica.removeSegment(segment.SegmentID, segmentTypeSealed) + } + log.Warn("failed to watch Delta channel while load segment", zap.Int64("collectionID", l.req.CollectionID), + zap.Int64("replicaID", l.req.ReplicaID), zap.Error(err)) + return err + } + + runningGroup, groupCtx := errgroup.WithContext(l.ctx) + for _, deltaPosition := range l.req.DeltaPositions { + pos := deltaPosition + runningGroup.Go(func() error { + // reload data from dml channel + return l.node.loader.FromDmlCPLoadDelete(groupCtx, l.req.CollectionID, pos) + }) + } + err = runningGroup.Wait() + if err != nil { + for _, segment := range l.req.Infos { + l.node.metaReplica.removeSegment(segment.SegmentID, segmentTypeSealed) + } + log.Warn("failed to load delete data while load segment", zap.Int64("collectionID", l.req.CollectionID), + zap.Int64("replicaID", l.req.ReplicaID), zap.Error(err)) + return err + } + + log.Info("LoadSegmentTask Execute done", zap.Int64("collectionID", l.req.CollectionID), + zap.Int64("replicaID", l.req.ReplicaID), zap.Int64("msgID", l.req.Base.MsgID)) + return nil +} + +// internal helper function to subscribe delta channel +func (l *loadSegmentsTask) watchDeltaChannel(vchanName []string) error { + collectionID := l.req.CollectionID + var vDeltaChannels []string + VPDeltaChannels := make(map[string]string) + for _, v := range vchanName { + dc, err := funcutil.ConvertChannelName(v, Params.CommonCfg.RootCoordDml, Params.CommonCfg.RootCoordDelta) + if err != nil { + log.Warn("watchDeltaChannels, failed to convert deltaChannel from dmlChannel", zap.String("DmlChannel", v), zap.Error(err)) + return err + } + p := funcutil.ToPhysicalChannel(dc) + vDeltaChannels = append(vDeltaChannels, dc) + VPDeltaChannels[dc] = p + } + log.Info("Starting WatchDeltaChannels ...", + zap.Int64("collectionID", collectionID), + zap.Any("channels", VPDeltaChannels), + ) + + coll, err := l.node.metaReplica.getCollectionByID(collectionID) + if err != nil { + return err + } + + // filter out duplicated channels + vDeltaChannels = coll.AddVDeltaChannels(vDeltaChannels, VPDeltaChannels) + defer func() { + if err != nil { + for _, vDeltaChannel := range vDeltaChannels { + coll.removeVDeltaChannel(vDeltaChannel) + } + } + }() + + if len(vDeltaChannels) == 0 { + log.Warn("all delta channels has be added before, ignore watch delta requests") + return nil + } + + channel2FlowGraph, err := l.node.dataSyncService.addFlowGraphsForDeltaChannels(collectionID, vDeltaChannels) + if err != nil { + log.Warn("watchDeltaChannel, add flowGraph for deltaChannel failed", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels), zap.Error(err)) + return err + } + consumeSubName := funcutil.GenChannelSubName(Params.CommonCfg.QueryNodeSubName, collectionID, Params.QueryNodeCfg.GetNodeID()) + + // channels as consumer + for channel, fg := range channel2FlowGraph { + pchannel := VPDeltaChannels[channel] + // use pChannel to consume + err = fg.consumeFlowGraphFromLatest(pchannel, consumeSubName) + if err != nil { + log.Error("msgStream as consumer failed for deltaChannels", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels)) + break + } + } + + if err != nil { + log.Warn("watchDeltaChannel, add flowGraph for deltaChannel failed", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels), zap.Error(err)) + for _, fg := 
range channel2FlowGraph { + fg.flowGraph.Close() + } + gcChannels := make([]Channel, 0) + for channel := range channel2FlowGraph { + gcChannels = append(gcChannels, channel) + } + l.node.dataSyncService.removeFlowGraphsByDeltaChannels(gcChannels) + return err + } + + log.Info("watchDeltaChannel, add flowGraph for deltaChannel success", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels)) + + // create tSafe + for _, channel := range vDeltaChannels { + l.node.tSafeReplica.addTSafe(channel) + } + + // add tsafe watch in query shard if exists, we find no way to handle it if query shard not exist + for _, channel := range vDeltaChannels { + dmlChannel, err := funcutil.ConvertChannelName(channel, Params.CommonCfg.RootCoordDelta, Params.CommonCfg.RootCoordDml) + if err != nil { + log.Error("failed to convert delta channel to dml", zap.String("channel", channel), zap.Error(err)) + panic(err) + } + err = l.node.queryShardService.addQueryShard(collectionID, dmlChannel, l.req.GetReplicaID()) + if err != nil { + log.Error("failed to add shard Service to query shard", zap.String("channel", channel), zap.Error(err)) + panic(err) + } + } + + // start flow graphs + for _, fg := range channel2FlowGraph { + fg.flowGraph.Start() + } + + log.Info("WatchDeltaChannels done", zap.Int64("collectionID", collectionID), zap.String("ChannelIDs", fmt.Sprintln(vDeltaChannels))) + return nil +} diff --git a/internal/querynode/meta_replica.go b/internal/querynode/meta_replica.go index 012c97e10c..c230fe718b 100644 --- a/internal/querynode/meta_replica.go +++ b/internal/querynode/meta_replica.go @@ -417,9 +417,8 @@ func (replica *metaReplica) addPartitionPrivate(collection *Collection, partitio collection.addPartitionID(partitionID) var newPartition = newPartition(collection.ID(), partitionID) replica.partitions[partitionID] = newPartition + metrics.QueryNodeNumPartitions.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Set(float64(len(replica.partitions))) } - - metrics.QueryNodeNumPartitions.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Set(float64(len(replica.partitions))) return nil } diff --git a/internal/querynode/mock_test.go b/internal/querynode/mock_test.go index 02f5af7b8b..64a1475810 100644 --- a/internal/querynode/mock_test.go +++ b/internal/querynode/mock_test.go @@ -49,7 +49,9 @@ import ( "github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/funcutil" "github.com/milvus-io/milvus/internal/util/indexcgowrapper" + "github.com/milvus-io/milvus/internal/util/lock" "github.com/milvus-io/milvus/internal/util/typeutil" + "github.com/panjf2000/ants/v2" ) // ---------- unittest util functions ---------- @@ -1262,13 +1264,10 @@ func genSimpleReplicaWithSealSegment(ctx context.Context) (ReplicaInterface, err if err != nil { return nil, err } - col, err := r.getCollectionByID(defaultCollectionID) + _, err = r.getCollectionByID(defaultCollectionID) if err != nil { return nil, err } - col.addVChannels([]Channel{ - defaultDeltaChannel, - }) return r, nil } @@ -1661,6 +1660,12 @@ func genSimpleQueryNodeWithMQFactory(ctx context.Context, fac dependency.Factory node.etcdCli = etcdCli node.initSession() + node.taskPool, err = concurrency.NewPool(2, ants.WithPreAlloc(true)) + if err != nil { + log.Error("QueryNode init channel pool failed", zap.Error(err)) + return nil, err + } + node.taskLock = lock.NewKeyLock() etcdKV := etcdkv.NewEtcdKV(etcdCli, Params.EtcdCfg.MetaRootPath) node.etcdKV = etcdKV diff --git 
a/internal/querynode/query_node.go b/internal/querynode/query_node.go
index 8483b3b3fc..f6b61ccb7a 100644
--- a/internal/querynode/query_node.go
+++ b/internal/querynode/query_node.go
@@ -54,6 +54,7 @@ import (
 	"github.com/milvus-io/milvus/internal/util/concurrency"
 	"github.com/milvus-io/milvus/internal/util/dependency"
 	"github.com/milvus-io/milvus/internal/util/initcore"
+	"github.com/milvus-io/milvus/internal/util/lock"
 	"github.com/milvus-io/milvus/internal/util/metricsinfo"
 	"github.com/milvus-io/milvus/internal/util/paramtable"
 	"github.com/milvus-io/milvus/internal/util/sessionutil"
@@ -120,6 +121,10 @@ type QueryNode struct {
 
 	// cgoPool is the worker pool to control concurrency of cgo call
 	cgoPool *concurrency.Pool
+	// pool for load/release channel and segment tasks
+	taskPool *concurrency.Pool
+	// lock to avoid running tasks for the same channel/segment multiple times
+	taskLock *lock.KeyLock
 }
 
 // NewQueryNode will return a QueryNode with abnormal state.
@@ -258,6 +263,15 @@ func (node *QueryNode) Init() error {
 			return
 		}
 
+		node.taskPool, err = concurrency.NewPool(cpuNum, ants.WithPreAlloc(true))
+		if err != nil {
+			log.Error("QueryNode init channel pool failed", zap.Error(err))
+			initError = err
+			return
+		}
+
+		node.taskLock = lock.NewKeyLock()
+
 		// ensure every cgopool go routine is locked with a OS thread
 		// so openmp in knowhere won't create too much request
 		sig := make(chan struct{})
diff --git a/internal/querynode/query_shard_service.go b/internal/querynode/query_shard_service.go
index e65fd79497..dd06f388fb 100644
--- a/internal/querynode/query_shard_service.go
+++ b/internal/querynode/query_shard_service.go
@@ -28,6 +28,7 @@ import (
 	"go.uber.org/zap"
 )
 
+// TODO, remove queryShardService, it's not used anymore.
 type queryShardService struct {
 	ctx    context.Context
 	cancel context.CancelFunc
diff --git a/internal/querynode/segment_loader.go b/internal/querynode/segment_loader.go
index 8cb3f7c83f..81a0f172ca 100644
--- a/internal/querynode/segment_loader.go
+++ b/internal/querynode/segment_loader.go
@@ -20,12 +20,13 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"math/rand"
 	"path"
 	"runtime"
 	"runtime/debug"
 	"strconv"
+	"time"
 
-	"github.com/panjf2000/ants/v2"
 	"go.uber.org/zap"
 
 	"github.com/milvus-io/milvus-proto/go-api/commonpb"
@@ -46,6 +47,8 @@ import (
 	"github.com/milvus-io/milvus/internal/util/hardware"
 	"github.com/milvus-io/milvus/internal/util/indexparamcheck"
 	"github.com/milvus-io/milvus/internal/util/timerecord"
+	"github.com/milvus-io/milvus/internal/util/tsoutil"
+	"github.com/panjf2000/ants/v2"
 )
 
 const (
@@ -703,7 +706,7 @@ func (loader *segmentLoader) loadDeltaLogs(ctx context.Context, segment *Segment
 }
 
 func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collectionID int64, position *internalpb.MsgPosition) error {
-	log.Info("from dml check point load delete", zap.Any("position", position))
+	startTs := time.Now()
 	stream, err := loader.factory.NewMsgStream(ctx)
 	if err != nil {
 		return err
@@ -717,7 +720,12 @@ func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collection
 	pChannelName := funcutil.ToPhysicalChannel(position.ChannelName)
 	position.ChannelName = pChannelName
 
-	stream.AsConsumer([]string{pChannelName}, fmt.Sprintf("querynode-%d-%d", Params.QueryNodeCfg.GetNodeID(), collectionID), mqwrapper.SubscriptionPositionUnknown)
+	ts, _ := tsoutil.ParseTS(position.Timestamp)
+
+	// Randomize the subName in case we are trying to load the same delta at the same time
+	subName := fmt.Sprintf("querynode-delta-loader-%d-%d-%d", Params.QueryNodeCfg.GetNodeID(), collectionID, rand.Int())
+
log.Info("from dml check point load delete", zap.Any("position", position), zap.String("subName", subName), zap.Time("positionTs", ts)) + stream.AsConsumer([]string{pChannelName}, subName, mqwrapper.SubscriptionPositionUnknown) // make sure seek position is earlier than lastMsgID, err := stream.GetLatestMsgID(pChannelName) if err != nil { @@ -730,7 +738,7 @@ func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collection } if reachLatest || lastMsgID.AtEarliestPosition() { - log.Info("there is no more delta msg", zap.Int64("Collection ID", collectionID), zap.String("channel", pChannelName)) + log.Info("there is no more delta msg", zap.Int64("collectionID", collectionID), zap.String("channel", pChannelName)) return nil } @@ -748,7 +756,7 @@ func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collection } log.Info("start read delta msg from seek position to last position", - zap.Int64("Collection ID", collectionID), zap.String("channel", pChannelName), zap.Any("seek pos", position), zap.Any("last msg", lastMsgID)) + zap.Int64("collectionID", collectionID), zap.String("channel", pChannelName), zap.Any("seekPos", position), zap.Any("lastMsg", lastMsgID)) hasMore := true for hasMore { select { @@ -791,7 +799,7 @@ func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collection ret, err := lastMsgID.LessOrEqualThan(tsMsg.Position().MsgID) if err != nil { log.Warn("check whether current MsgID less than last MsgID failed", - zap.Int64("Collection ID", collectionID), zap.String("channel", pChannelName), zap.Error(err)) + zap.Int64("collectionID", collectionID), zap.String("channel", pChannelName), zap.Error(err)) return err } @@ -803,8 +811,8 @@ func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collection } } - log.Info("All data has been read, there is no more data", zap.Int64("Collection ID", collectionID), - zap.String("channel", pChannelName), zap.Any("msg id", position.GetMsgID())) + log.Info("All data has been read, there is no more data", zap.Int64("collectionID", collectionID), + zap.String("channel", pChannelName), zap.Any("msgID", position.GetMsgID())) for segmentID, pks := range delData.deleteIDs { segment, err := loader.metaReplica.getSegmentByID(segmentID, segmentTypeSealed) if err != nil { @@ -821,7 +829,7 @@ func (loader *segmentLoader) FromDmlCPLoadDelete(ctx context.Context, collection } } - log.Info("from dml check point load done", zap.Any("msg id", position.GetMsgID())) + log.Info("from dml check point load done", zap.String("subName", subName), zap.Any("timeTake", time.Since(startTs))) return nil } diff --git a/internal/querynode/task.go b/internal/querynode/task.go index 861ae413fa..1aad4eba87 100644 --- a/internal/querynode/task.go +++ b/internal/querynode/task.go @@ -19,23 +19,13 @@ package querynode import ( "context" "errors" - "fmt" - "math" "runtime/debug" "go.uber.org/zap" - "golang.org/x/sync/errgroup" - - "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus/internal/log" - "github.com/milvus-io/milvus/internal/proto/datapb" - "github.com/milvus-io/milvus/internal/proto/internalpb" queryPb "github.com/milvus-io/milvus/internal/proto/querypb" - "github.com/milvus-io/milvus/internal/util/commonpbutil" - "github.com/milvus-io/milvus/internal/util/funcutil" "github.com/milvus-io/milvus/internal/util/typeutil" - "github.com/samber/lo" ) type task interface { @@ -90,18 +80,6 @@ func (b *baseTask) Notify(err error) { b.done <- err } -type watchDmChannelsTask struct { - baseTask 
- req *queryPb.WatchDmChannelsRequest - node *QueryNode -} - -type loadSegmentsTask struct { - baseTask - req *queryPb.LoadSegmentsRequest - node *QueryNode -} - type releaseCollectionTask struct { baseTask req *queryPb.ReleaseCollectionRequest @@ -114,493 +92,6 @@ type releasePartitionsTask struct { node *QueryNode } -// watchDmChannelsTask -func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) { - collectionID := w.req.CollectionID - partitionIDs := w.req.GetPartitionIDs() - - lType := w.req.GetLoadMeta().GetLoadType() - if lType == queryPb.LoadType_UnKnownType { - // if no partitionID is specified, load type is load collection - if len(partitionIDs) != 0 { - lType = queryPb.LoadType_LoadPartition - } else { - lType = queryPb.LoadType_LoadCollection - } - } - - // get all vChannels - var vChannels, pChannels []Channel - VPChannels := make(map[string]string) // map[vChannel]pChannel - for _, info := range w.req.Infos { - v := info.ChannelName - p := funcutil.ToPhysicalChannel(info.ChannelName) - vChannels = append(vChannels, v) - pChannels = append(pChannels, p) - VPChannels[v] = p - } - - if len(VPChannels) != len(vChannels) { - return errors.New("get physical channels failed, illegal channel length, collectionID = " + fmt.Sprintln(collectionID)) - } - - log.Info("Starting WatchDmChannels ...", - zap.String("collectionName", w.req.Schema.Name), - zap.Int64("collectionID", collectionID), - zap.Int64("replicaID", w.req.GetReplicaID()), - zap.Any("load type", lType), - zap.Strings("vChannels", vChannels), - zap.Strings("pChannels", pChannels), - ) - - // init collection meta - coll := w.node.metaReplica.addCollection(collectionID, w.req.Schema) - - loadedChannelCounter := 0 - for _, toLoadChannel := range vChannels { - for _, loadedChannel := range coll.vChannels { - if toLoadChannel == loadedChannel { - loadedChannelCounter++ - break - } - } - } - - // check if all channels has been loaded, if YES, should do nothing and return - // in case of query coord trigger same watchDmChannelTask on multi - if len(vChannels) == loadedChannelCounter { - log.Warn("All channel has been loaded, skip this watchDmChannelsTask") - return nil - } - - //add shard cluster - for _, vchannel := range vChannels { - w.node.ShardClusterService.addShardCluster(w.req.GetCollectionID(), w.req.GetReplicaID(), vchannel) - } - - defer func() { - if err != nil { - for _, vchannel := range vChannels { - w.node.ShardClusterService.releaseShardCluster(vchannel) - } - } - }() - - // load growing segments - unFlushedSegments := make([]*queryPb.SegmentLoadInfo, 0) - unFlushedSegmentIDs := make([]UniqueID, 0) - for _, info := range w.req.Infos { - for _, ufInfoID := range info.GetUnflushedSegmentIds() { - // unFlushed segment may not have binLogs, skip loading - ufInfo := w.req.GetSegmentInfos()[ufInfoID] - if ufInfo == nil { - log.Warn("an unflushed segment is not found in segment infos", zap.Int64("segment ID", ufInfoID)) - continue - } - if len(ufInfo.GetBinlogs()) > 0 { - unFlushedSegments = append(unFlushedSegments, &queryPb.SegmentLoadInfo{ - SegmentID: ufInfo.ID, - PartitionID: ufInfo.PartitionID, - CollectionID: ufInfo.CollectionID, - BinlogPaths: ufInfo.Binlogs, - NumOfRows: ufInfo.NumOfRows, - Statslogs: ufInfo.Statslogs, - Deltalogs: ufInfo.Deltalogs, - InsertChannel: ufInfo.InsertChannel, - }) - unFlushedSegmentIDs = append(unFlushedSegmentIDs, ufInfo.GetID()) - } else { - log.Info("skip segment which binlog is empty", zap.Int64("segmentID", ufInfo.ID)) - } - } - } - req := 
&queryPb.LoadSegmentsRequest{ - Base: commonpbutil.NewMsgBase( - commonpbutil.WithMsgType(commonpb.MsgType_LoadSegments), - commonpbutil.WithMsgID(w.req.Base.MsgID), // use parent task's msgID - ), - Infos: unFlushedSegments, - CollectionID: collectionID, - Schema: w.req.GetSchema(), - LoadMeta: w.req.GetLoadMeta(), - } - - // update partition info from unFlushedSegments and loadMeta - for _, info := range req.Infos { - err = w.node.metaReplica.addPartition(collectionID, info.PartitionID) - if err != nil { - return err - } - } - for _, partitionID := range req.GetLoadMeta().GetPartitionIDs() { - err = w.node.metaReplica.addPartition(collectionID, partitionID) - if err != nil { - return err - } - } - - log.Info("loading growing segments in WatchDmChannels...", - zap.Int64("collectionID", collectionID), - zap.Int64s("unFlushedSegmentIDs", unFlushedSegmentIDs), - ) - err = w.node.loader.LoadSegment(w.ctx, req, segmentTypeGrowing) - if err != nil { - log.Warn(err.Error()) - return err - } - log.Info("successfully load growing segments done in WatchDmChannels", - zap.Int64("collectionID", collectionID), - zap.Int64s("unFlushedSegmentIDs", unFlushedSegmentIDs), - ) - - // remove growing segment if watch dmChannels failed - defer func() { - if err != nil { - for _, segmentID := range unFlushedSegmentIDs { - w.node.metaReplica.removeSegment(segmentID, segmentTypeGrowing) - } - } - }() - - // So far, we don't support to enable each node with two different channel - consumeSubName := funcutil.GenChannelSubName(Params.CommonCfg.QueryNodeSubName, collectionID, Params.QueryNodeCfg.GetNodeID()) - - // group channels by to seeking or consuming - channel2SeekPosition := make(map[string]*internalpb.MsgPosition) - - // for channel with no position - channel2AsConsumerPosition := make(map[string]*internalpb.MsgPosition) - for _, info := range w.req.Infos { - if info.SeekPosition == nil || len(info.SeekPosition.MsgID) == 0 { - channel2AsConsumerPosition[info.ChannelName] = info.SeekPosition - continue - } - info.SeekPosition.MsgGroup = consumeSubName - channel2SeekPosition[info.ChannelName] = info.SeekPosition - } - log.Info("watchDMChannel, group channels done", zap.Int64("collectionID", collectionID)) - - // add excluded segments for unFlushed segments, - // unFlushed segments before check point should be filtered out. - unFlushedCheckPointInfos := make([]*datapb.SegmentInfo, 0) - for _, info := range w.req.Infos { - for _, ufsID := range info.GetUnflushedSegmentIds() { - unFlushedCheckPointInfos = append(unFlushedCheckPointInfos, w.req.SegmentInfos[ufsID]) - } - } - w.node.metaReplica.addExcludedSegments(collectionID, unFlushedCheckPointInfos) - unflushedSegmentIDs := make([]UniqueID, len(unFlushedCheckPointInfos)) - for i, segInfo := range unFlushedCheckPointInfos { - unflushedSegmentIDs[i] = segInfo.GetID() - } - log.Info("watchDMChannel, add check points info for unflushed segments done", - zap.Int64("collectionID", collectionID), - zap.Any("unflushedSegmentIDs", unflushedSegmentIDs), - ) - - // add excluded segments for flushed segments, - // flushed segments with later check point than seekPosition should be filtered out. 
- flushedCheckPointInfos := make([]*datapb.SegmentInfo, 0) - for _, info := range w.req.Infos { - for _, flushedSegmentID := range info.GetFlushedSegmentIds() { - flushedSegment := w.req.SegmentInfos[flushedSegmentID] - for _, position := range channel2SeekPosition { - if flushedSegment.DmlPosition != nil && - flushedSegment.DmlPosition.ChannelName == position.ChannelName && - flushedSegment.DmlPosition.Timestamp > position.Timestamp { - flushedCheckPointInfos = append(flushedCheckPointInfos, flushedSegment) - } - } - } - } - w.node.metaReplica.addExcludedSegments(collectionID, flushedCheckPointInfos) - flushedSegmentIDs := make([]UniqueID, len(flushedCheckPointInfos)) - for i, segInfo := range flushedCheckPointInfos { - flushedSegmentIDs[i] = segInfo.GetID() - } - log.Info("watchDMChannel, add check points info for flushed segments done", - zap.Int64("collectionID", collectionID), - zap.Any("flushedSegmentIDs", flushedSegmentIDs), - ) - - // add excluded segments for dropped segments, - // exclude all msgs with dropped segment id - // DO NOT refer to dropped segment info, see issue https://github.com/milvus-io/milvus/issues/19704 - var droppedCheckPointInfos []*datapb.SegmentInfo - for _, info := range w.req.Infos { - for _, droppedSegmentID := range info.GetDroppedSegmentIds() { - droppedCheckPointInfos = append(droppedCheckPointInfos, &datapb.SegmentInfo{ - ID: droppedSegmentID, - CollectionID: collectionID, - InsertChannel: info.GetChannelName(), - DmlPosition: &internalpb.MsgPosition{ - ChannelName: info.GetChannelName(), - Timestamp: math.MaxUint64, - }, - }) - } - } - w.node.metaReplica.addExcludedSegments(collectionID, droppedCheckPointInfos) - droppedSegmentIDs := make([]UniqueID, len(droppedCheckPointInfos)) - for i, segInfo := range droppedCheckPointInfos { - droppedSegmentIDs[i] = segInfo.GetID() - } - log.Info("watchDMChannel, add check points info for dropped segments done", - zap.Int64("collectionID", collectionID), - zap.Any("droppedSegmentIDs", droppedSegmentIDs), - ) - - // add flow graph - channel2FlowGraph, err := w.node.dataSyncService.addFlowGraphsForDMLChannels(collectionID, vChannels) - if err != nil { - log.Warn("watchDMChannel, add flowGraph for dmChannels failed", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels), zap.Error(err)) - return err - } - log.Info("Query node add DML flow graphs", zap.Int64("collectionID", collectionID), zap.Any("channels", vChannels)) - - // channels as consumer - for channel, fg := range channel2FlowGraph { - if _, ok := channel2AsConsumerPosition[channel]; ok { - // use pChannel to consume - err = fg.consumeFlowGraph(VPChannels[channel], consumeSubName) - if err != nil { - log.Error("msgStream as consumer failed for dmChannels", zap.Int64("collectionID", collectionID), zap.String("vChannel", channel)) - break - } - } - - if pos, ok := channel2SeekPosition[channel]; ok { - pos.MsgGroup = consumeSubName - // use pChannel to seek - pos.ChannelName = VPChannels[channel] - err = fg.consumeFlowGraphFromPosition(pos) - if err != nil { - log.Error("msgStream seek failed for dmChannels", zap.Int64("collectionID", collectionID), zap.String("vChannel", channel)) - break - } - } - } - - if err != nil { - log.Warn("watchDMChannel, add flowGraph for dmChannels failed", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels), zap.Error(err)) - for _, fg := range channel2FlowGraph { - fg.flowGraph.Close() - } - gcChannels := make([]Channel, 0) - for channel := range channel2FlowGraph { - gcChannels = 
append(gcChannels, channel) - } - w.node.dataSyncService.removeFlowGraphsByDMLChannels(gcChannels) - return err - } - - log.Info("watchDMChannel, add flowGraph for dmChannels success", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels)) - - coll.addVChannels(vChannels) - coll.addPChannels(pChannels) - coll.setLoadType(lType) - - log.Info("watchDMChannel, init replica done", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels)) - - // create tSafe - for _, channel := range vChannels { - w.node.tSafeReplica.addTSafe(channel) - } - - // add tsafe watch in query shard if exists - for _, dmlChannel := range vChannels { - w.node.queryShardService.addQueryShard(collectionID, dmlChannel, w.req.GetReplicaID()) - } - - // start flow graphs - for _, fg := range channel2FlowGraph { - fg.flowGraph.Start() - } - - log.Info("WatchDmChannels done", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels)) - return nil -} - -// internal helper function to subscribe delta channel -func (l *loadSegmentsTask) watchDeltaChannel(vchanName []string) error { - collectionID := l.req.CollectionID - var vDeltaChannels, pDeltaChannels []string - VPDeltaChannels := make(map[string]string) - for _, v := range vchanName { - dc, err := funcutil.ConvertChannelName(v, Params.CommonCfg.RootCoordDml, Params.CommonCfg.RootCoordDelta) - if err != nil { - log.Warn("watchDeltaChannels, failed to convert deltaChannel from dmlChannel", zap.String("DmlChannel", v), zap.Error(err)) - return err - } - p := funcutil.ToPhysicalChannel(dc) - vDeltaChannels = append(vDeltaChannels, dc) - pDeltaChannels = append(pDeltaChannels, p) - VPDeltaChannels[dc] = p - } - log.Info("Starting WatchDeltaChannels ...", - zap.Int64("collectionID", collectionID), - zap.Any("channels", VPDeltaChannels), - ) - - coll, err := l.node.metaReplica.getCollectionByID(collectionID) - if err != nil { - return err - } - - channel2FlowGraph, err := l.node.dataSyncService.addFlowGraphsForDeltaChannels(collectionID, vDeltaChannels) - if err != nil { - log.Warn("watchDeltaChannel, add flowGraph for deltaChannel failed", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels), zap.Error(err)) - return err - } - consumeSubName := funcutil.GenChannelSubName(Params.CommonCfg.QueryNodeSubName, collectionID, Params.QueryNodeCfg.GetNodeID()) - - // channels as consumer - for channel, fg := range channel2FlowGraph { - pchannel := VPDeltaChannels[channel] - // use pChannel to consume - err = fg.consumeFlowGraphFromLatest(pchannel, consumeSubName) - if err != nil { - log.Error("msgStream as consumer failed for deltaChannels", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels)) - break - } - } - - if err != nil { - log.Warn("watchDeltaChannel, add flowGraph for deltaChannel failed", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels), zap.Error(err)) - for _, fg := range channel2FlowGraph { - fg.flowGraph.Close() - } - gcChannels := make([]Channel, 0) - for channel := range channel2FlowGraph { - gcChannels = append(gcChannels, channel) - } - l.node.dataSyncService.removeFlowGraphsByDeltaChannels(gcChannels) - return err - } - - log.Info("watchDeltaChannel, add flowGraph for deltaChannel success", zap.Int64("collectionID", collectionID), zap.Strings("vDeltaChannels", vDeltaChannels)) - - //set collection replica - coll.addVDeltaChannels(vDeltaChannels) - coll.addPDeltaChannels(pDeltaChannels) - - // create tSafe 
- for _, channel := range vDeltaChannels { - l.node.tSafeReplica.addTSafe(channel) - } - - // add tsafe watch in query shard if exists, we find no way to handle it if query shard not exist - for _, channel := range vDeltaChannels { - dmlChannel, err := funcutil.ConvertChannelName(channel, Params.CommonCfg.RootCoordDelta, Params.CommonCfg.RootCoordDml) - if err != nil { - log.Error("failed to convert delta channel to dml", zap.String("channel", channel), zap.Error(err)) - panic(err) - } - err = l.node.queryShardService.addQueryShard(collectionID, dmlChannel, l.req.GetReplicaID()) - if err != nil { - log.Error("failed to add shard Service to query shard", zap.String("channel", channel), zap.Error(err)) - panic(err) - } - } - - // start flow graphs - for _, fg := range channel2FlowGraph { - fg.flowGraph.Start() - } - - log.Info("WatchDeltaChannels done", zap.Int64("collectionID", collectionID), zap.String("ChannelIDs", fmt.Sprintln(vDeltaChannels))) - return nil -} - -// loadSegmentsTask -func (l *loadSegmentsTask) PreExecute(ctx context.Context) error { - log.Info("LoadSegmentTask PreExecute start", zap.Int64("msgID", l.req.Base.MsgID)) - var err error - // init meta - collectionID := l.req.GetCollectionID() - l.node.metaReplica.addCollection(collectionID, l.req.GetSchema()) - for _, partitionID := range l.req.GetLoadMeta().GetPartitionIDs() { - err = l.node.metaReplica.addPartition(collectionID, partitionID) - if err != nil { - return err - } - } - - // filter segments that are already loaded in this querynode - var filteredInfos []*queryPb.SegmentLoadInfo - for _, info := range l.req.Infos { - has, err := l.node.metaReplica.hasSegment(info.SegmentID, segmentTypeSealed) - if err != nil { - return err - } - if !has { - filteredInfos = append(filteredInfos, info) - } else { - log.Debug("ignore segment that is already loaded", zap.Int64("collectionID", info.SegmentID), zap.Int64("segmentID", info.SegmentID)) - } - } - l.req.Infos = filteredInfos - log.Info("LoadSegmentTask PreExecute done", zap.Int64("msgID", l.req.Base.MsgID)) - return nil -} - -func (l *loadSegmentsTask) Execute(ctx context.Context) error { - log.Info("LoadSegmentTask Execute start", zap.Int64("msgID", l.req.Base.MsgID)) - - if len(l.req.Infos) == 0 { - log.Info("all segments loaded", - zap.Int64("msgID", l.req.GetBase().GetMsgID())) - return nil - } - - segmentIDs := lo.Map(l.req.Infos, func(info *queryPb.SegmentLoadInfo, idx int) UniqueID { return info.SegmentID }) - l.node.metaReplica.addSegmentsLoadingList(segmentIDs) - defer l.node.metaReplica.removeSegmentsLoadingList(segmentIDs) - err := l.node.loader.LoadSegment(l.ctx, l.req, segmentTypeSealed) - if err != nil { - log.Warn("failed to load segment", zap.Int64("collectionID", l.req.CollectionID), - zap.Int64("replicaID", l.req.ReplicaID), zap.Error(err)) - return err - } - vchanName := make([]string, 0) - for _, deltaPosition := range l.req.DeltaPositions { - vchanName = append(vchanName, deltaPosition.ChannelName) - } - // TODO delta channel need to released 1. if other watchDeltaChannel fail 2. 
when segment release - err = l.watchDeltaChannel(vchanName) - if err != nil { - // roll back - for _, segment := range l.req.Infos { - l.node.metaReplica.removeSegment(segment.SegmentID, segmentTypeSealed) - } - log.Warn("failed to watch Delta channel while load segment", zap.Int64("collectionID", l.req.CollectionID), - zap.Int64("replicaID", l.req.ReplicaID), zap.Error(err)) - return err - } - - runningGroup, groupCtx := errgroup.WithContext(l.ctx) - for _, deltaPosition := range l.req.DeltaPositions { - pos := deltaPosition - runningGroup.Go(func() error { - // reload data from dml channel - return l.node.loader.FromDmlCPLoadDelete(groupCtx, l.req.CollectionID, pos) - }) - } - err = runningGroup.Wait() - if err != nil { - for _, segment := range l.req.Infos { - l.node.metaReplica.removeSegment(segment.SegmentID, segmentTypeSealed) - } - log.Warn("failed to load delete data while load segment", zap.Int64("collectionID", l.req.CollectionID), - zap.Int64("replicaID", l.req.ReplicaID), zap.Error(err)) - return err - } - - log.Info("LoadSegmentTask Execute done", zap.Int64("collectionID", l.req.CollectionID), - zap.Int64("replicaID", l.req.ReplicaID), zap.Int64("msgID", l.req.Base.MsgID)) - return nil -} - func (r *releaseCollectionTask) Execute(ctx context.Context) error { log.Info("Execute release collection task", zap.Any("collectionID", r.req.CollectionID)) diff --git a/internal/querynode/watch_dm_channels_task.go b/internal/querynode/watch_dm_channels_task.go new file mode 100644 index 0000000000..0843b6ba2e --- /dev/null +++ b/internal/querynode/watch_dm_channels_task.go @@ -0,0 +1,354 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package querynode + +import ( + "context" + "errors" + "fmt" + "math" + + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/commonpb" + "github.com/milvus-io/milvus/internal/log" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + queryPb "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/commonpbutil" + "github.com/milvus-io/milvus/internal/util/funcutil" +) + +type watchDmChannelsTask struct { + baseTask + req *queryPb.WatchDmChannelsRequest + node *QueryNode +} + +// watchDmChannelsTask +func (w *watchDmChannelsTask) Execute(ctx context.Context) (err error) { + collectionID := w.req.CollectionID + partitionIDs := w.req.GetPartitionIDs() + + lType := w.req.GetLoadMeta().GetLoadType() + if lType == queryPb.LoadType_UnKnownType { + // if no partitionID is specified, load type is load collection + if len(partitionIDs) != 0 { + lType = queryPb.LoadType_LoadPartition + } else { + lType = queryPb.LoadType_LoadCollection + } + } + + // get all vChannels + var vChannels []Channel + VPChannels := make(map[string]string) // map[vChannel]pChannel + for _, info := range w.req.Infos { + v := info.ChannelName + p := funcutil.ToPhysicalChannel(info.ChannelName) + vChannels = append(vChannels, v) + VPChannels[v] = p + } + + if len(VPChannels) != len(vChannels) { + return errors.New("get physical channels failed, illegal channel length, collectionID = " + fmt.Sprintln(collectionID)) + } + + log.Info("Starting WatchDmChannels ...", + zap.String("collectionName", w.req.Schema.Name), + zap.Int64("collectionID", collectionID), + zap.Int64("replicaID", w.req.GetReplicaID()), + zap.String("load type", lType.String()), + zap.Strings("vChannels", vChannels), + ) + + // init collection meta + coll := w.node.metaReplica.addCollection(collectionID, w.req.Schema) + + // filter out the already exist channels + vChannels = coll.AddChannels(vChannels, VPChannels) + defer func() { + if err != nil { + for _, vChannel := range vChannels { + coll.removeVChannel(vChannel) + } + } + }() + + if len(vChannels) == 0 { + log.Warn("all channels has be added before, ignore watch dml requests") + return nil + } + + //add shard cluster + for _, vchannel := range vChannels { + w.node.ShardClusterService.addShardCluster(w.req.GetCollectionID(), w.req.GetReplicaID(), vchannel) + } + + defer func() { + if err != nil { + for _, vchannel := range vChannels { + w.node.ShardClusterService.releaseShardCluster(vchannel) + } + } + }() + + unFlushedSegmentIDs, err := w.LoadGrowingSegments(ctx, collectionID) + + // remove growing segment if watch dmChannels failed + defer func() { + if err != nil { + for _, segmentID := range unFlushedSegmentIDs { + w.node.metaReplica.removeSegment(segmentID, segmentTypeGrowing) + } + } + }() + + channel2FlowGraph, err := w.initFlowGraph(ctx, collectionID, vChannels, VPChannels) + if err != nil { + return err + } + + coll.setLoadType(lType) + + log.Info("watchDMChannel, init replica done", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels)) + + // create tSafe + for _, channel := range vChannels { + w.node.tSafeReplica.addTSafe(channel) + } + + // add tsafe watch in query shard if exists + for _, dmlChannel := range vChannels { + w.node.queryShardService.addQueryShard(collectionID, dmlChannel, w.req.GetReplicaID()) + } + + // start flow graphs + for _, fg := range channel2FlowGraph { + fg.flowGraph.Start() + } + + log.Info("WatchDmChannels done", 
zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels)) + return nil +} + +func (w *watchDmChannelsTask) LoadGrowingSegments(ctx context.Context, collectionID UniqueID) ([]UniqueID, error) { + // load growing segments + unFlushedSegments := make([]*queryPb.SegmentLoadInfo, 0) + unFlushedSegmentIDs := make([]UniqueID, 0) + for _, info := range w.req.Infos { + for _, ufInfoID := range info.GetUnflushedSegmentIds() { + // unFlushed segment may not have binLogs, skip loading + ufInfo := w.req.GetSegmentInfos()[ufInfoID] + if ufInfo == nil { + log.Warn("an unflushed segment is not found in segment infos", zap.Int64("segment ID", ufInfoID)) + continue + } + if len(ufInfo.GetBinlogs()) > 0 { + unFlushedSegments = append(unFlushedSegments, &queryPb.SegmentLoadInfo{ + SegmentID: ufInfo.ID, + PartitionID: ufInfo.PartitionID, + CollectionID: ufInfo.CollectionID, + BinlogPaths: ufInfo.Binlogs, + NumOfRows: ufInfo.NumOfRows, + Statslogs: ufInfo.Statslogs, + Deltalogs: ufInfo.Deltalogs, + InsertChannel: ufInfo.InsertChannel, + }) + unFlushedSegmentIDs = append(unFlushedSegmentIDs, ufInfo.GetID()) + } else { + log.Info("skip segment which binlog is empty", zap.Int64("segmentID", ufInfo.ID)) + } + } + } + req := &queryPb.LoadSegmentsRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_LoadSegments), + commonpbutil.WithMsgID(w.req.Base.MsgID), // use parent task's msgID + ), + Infos: unFlushedSegments, + CollectionID: collectionID, + Schema: w.req.GetSchema(), + LoadMeta: w.req.GetLoadMeta(), + } + + // update partition info from unFlushedSegments and loadMeta + for _, info := range req.Infos { + err := w.node.metaReplica.addPartition(collectionID, info.PartitionID) + if err != nil { + return nil, err + } + } + for _, partitionID := range req.GetLoadMeta().GetPartitionIDs() { + err := w.node.metaReplica.addPartition(collectionID, partitionID) + if err != nil { + return nil, err + } + } + + log.Info("loading growing segments in WatchDmChannels...", + zap.Int64("collectionID", collectionID), + zap.Int64s("unFlushedSegmentIDs", unFlushedSegmentIDs), + ) + err := w.node.loader.LoadSegment(w.ctx, req, segmentTypeGrowing) + if err != nil { + log.Warn("failed to load segment", zap.Int64("collection", collectionID), zap.Error(err)) + return nil, err + } + log.Info("successfully load growing segments done in WatchDmChannels", + zap.Int64("collectionID", collectionID), + zap.Int64s("unFlushedSegmentIDs", unFlushedSegmentIDs), + ) + return unFlushedSegmentIDs, nil +} + +func (w *watchDmChannelsTask) initFlowGraph(ctx context.Context, collectionID UniqueID, vChannels []Channel, VPChannels map[string]string) (map[string]*queryNodeFlowGraph, error) { + // So far, we don't support to enable each node with two different channel + consumeSubName := funcutil.GenChannelSubName(Params.CommonCfg.QueryNodeSubName, collectionID, Params.QueryNodeCfg.GetNodeID()) + + // group channels by to seeking or consuming + channel2SeekPosition := make(map[string]*internalpb.MsgPosition) + + // for channel with no position + channel2AsConsumerPosition := make(map[string]*internalpb.MsgPosition) + for _, info := range w.req.Infos { + if info.SeekPosition == nil || len(info.SeekPosition.MsgID) == 0 { + channel2AsConsumerPosition[info.ChannelName] = info.SeekPosition + continue + } + info.SeekPosition.MsgGroup = consumeSubName + channel2SeekPosition[info.ChannelName] = info.SeekPosition + } + log.Info("watchDMChannel, group channels done", zap.Int64("collectionID", collectionID)) + + // add 
excluded segments for unFlushed segments, + // unFlushed segments before check point should be filtered out. + unFlushedCheckPointInfos := make([]*datapb.SegmentInfo, 0) + for _, info := range w.req.Infos { + for _, ufsID := range info.GetUnflushedSegmentIds() { + unFlushedCheckPointInfos = append(unFlushedCheckPointInfos, w.req.SegmentInfos[ufsID]) + } + } + w.node.metaReplica.addExcludedSegments(collectionID, unFlushedCheckPointInfos) + unflushedSegmentIDs := make([]UniqueID, len(unFlushedCheckPointInfos)) + for i, segInfo := range unFlushedCheckPointInfos { + unflushedSegmentIDs[i] = segInfo.GetID() + } + log.Info("watchDMChannel, add check points info for unflushed segments done", + zap.Int64("collectionID", collectionID), + zap.Any("unflushedSegmentIDs", unflushedSegmentIDs), + ) + + // add excluded segments for flushed segments, + // flushed segments with later check point than seekPosition should be filtered out. + flushedCheckPointInfos := make([]*datapb.SegmentInfo, 0) + for _, info := range w.req.Infos { + for _, flushedSegmentID := range info.GetFlushedSegmentIds() { + flushedSegment := w.req.SegmentInfos[flushedSegmentID] + for _, position := range channel2SeekPosition { + if flushedSegment.DmlPosition != nil && + flushedSegment.DmlPosition.ChannelName == position.ChannelName && + flushedSegment.DmlPosition.Timestamp > position.Timestamp { + flushedCheckPointInfos = append(flushedCheckPointInfos, flushedSegment) + } + } + } + } + w.node.metaReplica.addExcludedSegments(collectionID, flushedCheckPointInfos) + flushedSegmentIDs := make([]UniqueID, len(flushedCheckPointInfos)) + for i, segInfo := range flushedCheckPointInfos { + flushedSegmentIDs[i] = segInfo.GetID() + } + log.Info("watchDMChannel, add check points info for flushed segments done", + zap.Int64("collectionID", collectionID), + zap.Any("flushedSegmentIDs", flushedSegmentIDs), + ) + + // add excluded segments for dropped segments, + // exclude all msgs with dropped segment id + // DO NOT refer to dropped segment info, see issue https://github.com/milvus-io/milvus/issues/19704 + var droppedCheckPointInfos []*datapb.SegmentInfo + for _, info := range w.req.Infos { + for _, droppedSegmentID := range info.GetDroppedSegmentIds() { + droppedCheckPointInfos = append(droppedCheckPointInfos, &datapb.SegmentInfo{ + ID: droppedSegmentID, + CollectionID: collectionID, + InsertChannel: info.GetChannelName(), + DmlPosition: &internalpb.MsgPosition{ + ChannelName: info.GetChannelName(), + Timestamp: math.MaxUint64, + }, + }) + } + } + w.node.metaReplica.addExcludedSegments(collectionID, droppedCheckPointInfos) + droppedSegmentIDs := make([]UniqueID, len(droppedCheckPointInfos)) + for i, segInfo := range droppedCheckPointInfos { + droppedSegmentIDs[i] = segInfo.GetID() + } + log.Info("watchDMChannel, add check points info for dropped segments done", + zap.Int64("collectionID", collectionID), + zap.Any("droppedSegmentIDs", droppedSegmentIDs), + ) + + // add flow graph + channel2FlowGraph, err := w.node.dataSyncService.addFlowGraphsForDMLChannels(collectionID, vChannels) + if err != nil { + log.Warn("watchDMChannel, add flowGraph for dmChannels failed", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels), zap.Error(err)) + return nil, err + } + log.Info("Query node add DML flow graphs", zap.Int64("collectionID", collectionID), zap.Any("channels", vChannels)) + + // channels as consumer + for channel, fg := range channel2FlowGraph { + if _, ok := channel2AsConsumerPosition[channel]; ok { + // use pChannel to 
consume + err = fg.consumeFlowGraph(VPChannels[channel], consumeSubName) + if err != nil { + log.Error("msgStream as consumer failed for dmChannels", zap.Int64("collectionID", collectionID), zap.String("vChannel", channel)) + break + } + } + + if pos, ok := channel2SeekPosition[channel]; ok { + pos.MsgGroup = consumeSubName + // use pChannel to seek + pos.ChannelName = VPChannels[channel] + err = fg.consumeFlowGraphFromPosition(pos) + if err != nil { + log.Error("msgStream seek failed for dmChannels", zap.Int64("collectionID", collectionID), zap.String("vChannel", channel)) + break + } + } + } + + if err != nil { + log.Warn("watchDMChannel, add flowGraph for dmChannels failed", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels), zap.Error(err)) + for _, fg := range channel2FlowGraph { + fg.flowGraph.Close() + } + gcChannels := make([]Channel, 0) + for channel := range channel2FlowGraph { + gcChannels = append(gcChannels, channel) + } + w.node.dataSyncService.removeFlowGraphsByDMLChannels(gcChannels) + return nil, err + } + + log.Info("watchDMChannel, add flowGraph for dmChannels success", zap.Int64("collectionID", collectionID), zap.Strings("vChannels", vChannels)) + return channel2FlowGraph, nil +} diff --git a/internal/util/lock/key_lock.go b/internal/util/lock/key_lock.go new file mode 100644 index 0000000000..95ef06bbb3 --- /dev/null +++ b/internal/util/lock/key_lock.go @@ -0,0 +1,131 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package lock + +import ( + "sync" + + "github.com/milvus-io/milvus/internal/log" + "go.uber.org/zap" +) + +type RefLock struct { + mutex sync.RWMutex + refCounter int +} + +func (m *RefLock) ref() { + m.refCounter++ +} + +func (m *RefLock) unref() { + m.refCounter-- +} + +func newRefLock() *RefLock { + c := RefLock{ + sync.RWMutex{}, + 0, + } + return &c +} + +type KeyLock struct { + keyLocksMutex sync.Mutex + refLocks map[string]*RefLock +} + +func NewKeyLock() *KeyLock { + keyLock := KeyLock{ + refLocks: make(map[string]*RefLock), + } + return &keyLock +} + +func (k *KeyLock) Lock(key string) { + k.keyLocksMutex.Lock() + // update the key map + if keyLock, ok := k.refLocks[key]; ok { + keyLock.ref() + + k.keyLocksMutex.Unlock() + keyLock.mutex.Lock() + } else { + newKLock := newRefLock() + newKLock.mutex.Lock() + k.refLocks[key] = newKLock + newKLock.ref() + + k.keyLocksMutex.Unlock() + return + } +} + +func (k *KeyLock) Unlock(lockedKey string) { + k.keyLocksMutex.Lock() + defer k.keyLocksMutex.Unlock() + keyLock, ok := k.refLocks[lockedKey] + if !ok { + log.Warn("Unlocking non-existing key", zap.String("key", lockedKey)) + return + } + keyLock.unref() + if keyLock.refCounter == 0 { + delete(k.refLocks, lockedKey) + } + keyLock.mutex.Unlock() +} + +func (k *KeyLock) RLock(key string) { + k.keyLocksMutex.Lock() + // update the key map + if keyLock, ok := k.refLocks[key]; ok { + keyLock.ref() + + k.keyLocksMutex.Unlock() + keyLock.mutex.RLock() + } else { + newKLock := newRefLock() + newKLock.mutex.RLock() + k.refLocks[key] = newKLock + newKLock.ref() + + k.keyLocksMutex.Unlock() + return + } +} + +func (k *KeyLock) RUnlock(lockedKey string) { + k.keyLocksMutex.Lock() + defer k.keyLocksMutex.Unlock() + keyLock, ok := k.refLocks[lockedKey] + if !ok { + log.Warn("Unlocking non-existing key", zap.String("key", lockedKey)) + return + } + keyLock.unref() + if keyLock.refCounter == 0 { + delete(k.refLocks, lockedKey) + } + keyLock.mutex.RUnlock() +} + +func (k *KeyLock) size() int { + k.keyLocksMutex.Lock() + defer k.keyLocksMutex.Unlock() + return len(k.refLocks) +} diff --git a/internal/util/lock/key_lock_test.go b/internal/util/lock/key_lock_test.go new file mode 100644 index 0000000000..9d06af0a82 --- /dev/null +++ b/internal/util/lock/key_lock_test.go @@ -0,0 +1,69 @@ +package lock + +import ( + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestKeyLock(t *testing.T) { + keys := []string{"Milvus", "Blazing", "Fast"} + + keyLock := NewKeyLock() + + keyLock.Lock(keys[0]) + keyLock.Lock(keys[1]) + keyLock.Lock(keys[2]) + + // should work + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + keyLock.Lock(keys[0]) + keyLock.Unlock(keys[0]) + wg.Done() + }() + + go func() { + keyLock.Lock(keys[0]) + keyLock.Unlock(keys[0]) + wg.Done() + }() + + assert.Equal(t, keyLock.size(), 3) + + time.Sleep(10 * time.Millisecond) + keyLock.Unlock(keys[0]) + keyLock.Unlock(keys[1]) + keyLock.Unlock(keys[2]) + wg.Wait() + + assert.Equal(t, keyLock.size(), 0) +} + +func TestKeyRLock(t *testing.T) { + keys := []string{"Milvus", "Blazing", "Fast"} + + keyLock := NewKeyLock() + + keyLock.RLock(keys[0]) + keyLock.RLock(keys[0]) + + // should work + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + keyLock.Lock(keys[0]) + keyLock.Unlock(keys[0]) + wg.Done() + }() + + time.Sleep(10 * time.Millisecond) + keyLock.RUnlock(keys[0]) + keyLock.RUnlock(keys[0]) + + wg.Wait() + assert.Equal(t, keyLock.size(), 0) +}
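
Usage note (not part of the patch): the new KeyLock utility is what allows load-segment and watch-channel tasks to run in parallel while still serializing work that touches the same key. The sketch below is a hypothetical caller, assuming it is compiled inside the milvus module (the package lives under internal/); the key names and print statements are illustrative only.

package main

import (
    "fmt"
    "sync"

    "github.com/milvus-io/milvus/internal/util/lock"
)

func main() {
    keyLock := lock.NewKeyLock()
    var wg sync.WaitGroup

    // Two workers contend on "segment-1"; the third works on "segment-2" concurrently.
    segments := []string{"segment-1", "segment-1", "segment-2"}
    for i, segmentID := range segments {
        wg.Add(1)
        go func(worker int, key string) {
            defer wg.Done()
            // Work on different keys proceeds in parallel; work on the same key
            // is serialized by the per-key mutex, which is dropped from the map
            // once its reference count reaches zero.
            keyLock.Lock(key)
            defer keyLock.Unlock(key)
            fmt.Printf("worker %d loading %s\n", worker, key)
        }(i, segmentID)
    }
    wg.Wait()
}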