diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index c0c532efa3..d379515a68 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -40,6 +40,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream" + "github.com/milvus-io/milvus/pkg/util/conc" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/lifetime" "github.com/milvus-io/milvus/pkg/util/merr" @@ -113,6 +114,7 @@ type shardDelegator struct { //dispatcherClient msgdispatcher.Client factory msgstream.Factory + sf conc.Singleflight[struct{}] loader segments.Loader tsCond *sync.Cond latestTsafe *atomic.Uint64 diff --git a/internal/querynodev2/delegator/delegator_data.go b/internal/querynodev2/delegator/delegator_data.go index ba0d20d3cd..b6dda08fb9 100644 --- a/internal/querynodev2/delegator/delegator_data.go +++ b/internal/querynodev2/delegator/delegator_data.go @@ -356,7 +356,38 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg req.Base.TargetID = req.GetDstNodeID() log.Info("worker loads segments...") - err = worker.LoadSegments(ctx, req) + + sLoad := func(ctx context.Context, req *querypb.LoadSegmentsRequest) error { + segmentID := req.GetInfos()[0].GetSegmentID() + nodeID := req.GetDstNodeID() + _, err, _ := sd.sf.Do(fmt.Sprintf("%d-%d", nodeID, segmentID), func() (struct{}, error) { + err := worker.LoadSegments(ctx, req) + return struct{}{}, err + }) + return err + } + + // separate infos into different load task + if len(req.GetInfos()) > 1 { + var reqs []*querypb.LoadSegmentsRequest + for _, info := range req.GetInfos() { + newReq := typeutil.Clone(req) + newReq.Infos = []*querypb.SegmentLoadInfo{info} + reqs = append(reqs, newReq) + } + + group, ctx := errgroup.WithContext(ctx) + for _, req := range reqs { + req := req + group.Go(func() error { + return sLoad(ctx, req) + }) + } + err = group.Wait() + } else { + err = sLoad(ctx, req) + } + if err != nil { log.Warn("worker failed to load segments", zap.Error(err)) return err diff --git a/internal/querynodev2/segments/segment_loader.go b/internal/querynodev2/segments/segment_loader.go index 54c0a100b6..3210dd456b 100644 --- a/internal/querynodev2/segments/segment_loader.go +++ b/internal/querynodev2/segments/segment_loader.go @@ -177,6 +177,11 @@ func (loader *segmentLoader) Load(ctx context.Context, infos := loader.prepare(segmentType, version, segments...) defer loader.unregister(infos...) + log.With( + zap.Int64s("requestSegments", lo.Map(segments, func(s *querypb.SegmentLoadInfo, _ int) int64 { return s.GetSegmentID() })), + zap.Int64s("preparedSegments", lo.Map(infos, func(s *querypb.SegmentLoadInfo, _ int) int64 { return s.GetSegmentID() })), + ) + // continue to wait other task done log.Info("start loading...", zap.Int("segmentNum", len(segments)), zap.Int("afterFilter", len(infos))) @@ -412,6 +417,10 @@ func (loader *segmentLoader) freeRequest(resource LoadResource) { } func (loader *segmentLoader) waitSegmentLoadDone(ctx context.Context, segmentType SegmentType, segmentIDs ...int64) error { + log := log.Ctx(ctx).With( + zap.String("segmentType", segmentType.String()), + zap.Int64s("segmentIDs", segmentIDs), + ) for _, segmentID := range segmentIDs { if loader.manager.Segment.GetWithType(segmentID, segmentType) != nil { continue @@ -440,6 +449,11 @@ func (loader *segmentLoader) waitSegmentLoadDone(ctx context.Context, segmentTyp result.cond.L.Unlock() close(signal) + if ctx.Err() != nil { + log.Warn("failed to wait segment loaded due to context done", zap.Int64("segmentID", segmentID)) + return ctx.Err() + } + if result.status.Load() == failure { log.Warn("failed to wait segment loaded", zap.Int64("segmentID", segmentID)) return merr.WrapErrSegmentLack(segmentID, "failed to wait segment loaded") diff --git a/internal/querynodev2/segments/segment_loader_test.go b/internal/querynodev2/segments/segment_loader_test.go index 4d38719ac2..2067e5b5cc 100644 --- a/internal/querynodev2/segments/segment_loader_test.go +++ b/internal/querynodev2/segments/segment_loader_test.go @@ -643,6 +643,29 @@ func (suite *SegmentLoaderDetailSuite) TestWaitSegmentLoadDone() { err := suite.loader.waitSegmentLoadDone(context.Background(), SegmentTypeSealed, suite.segmentID) suite.Error(err) }) + + suite.Run("wait_timeout", func() { + + suite.SetupTest() + + suite.segmentManager.EXPECT().GetBy(mock.Anything, mock.Anything).Return(nil) + suite.segmentManager.EXPECT().GetWithType(suite.segmentID, SegmentTypeSealed).RunAndReturn(func(segmentID int64, segmentType commonpb.SegmentState) Segment { + return nil + }) + suite.loader.prepare(SegmentTypeSealed, 0, &querypb.SegmentLoadInfo{ + SegmentID: suite.segmentID, + PartitionID: suite.partitionID, + CollectionID: suite.collectionID, + NumOfRows: 100, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err := suite.loader.waitSegmentLoadDone(ctx, SegmentTypeSealed, suite.segmentID) + suite.Error(err) + suite.True(merr.IsCanceledOrTimeout(err)) + }) } func TestSegmentLoader(t *testing.T) { diff --git a/internal/querynodev2/services_test.go b/internal/querynodev2/services_test.go index b01201f8ff..12bf40bb80 100644 --- a/internal/querynodev2/services_test.go +++ b/internal/querynodev2/services_test.go @@ -516,23 +516,26 @@ func (suite *ServiceSuite) TestLoadSegments_Int64() { suite.TestWatchDmChannelsInt64() // data schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) - req := &querypb.LoadSegmentsRequest{ - Base: &commonpb.MsgBase{ - MsgID: rand.Int63(), - TargetID: suite.node.session.ServerID, - }, - CollectionID: suite.collectionID, - DstNodeID: suite.node.session.ServerID, - Infos: suite.genSegmentLoadInfos(schema), - Schema: schema, - DeltaPositions: []*msgpb.MsgPosition{{Timestamp: 20000}}, - NeedTransfer: true, - } + infos := suite.genSegmentLoadInfos(schema) + for _, info := range infos { + req := &querypb.LoadSegmentsRequest{ + Base: &commonpb.MsgBase{ + MsgID: rand.Int63(), + TargetID: suite.node.session.ServerID, + }, + CollectionID: suite.collectionID, + DstNodeID: suite.node.session.ServerID, + Infos: []*querypb.SegmentLoadInfo{info}, + Schema: schema, + DeltaPositions: []*msgpb.MsgPosition{{Timestamp: 20000}}, + NeedTransfer: true, + } - // LoadSegment - status, err := suite.node.LoadSegments(ctx, req) - suite.NoError(err) - suite.Equal(commonpb.ErrorCode_Success, status.GetErrorCode()) + // LoadSegment + status, err := suite.node.LoadSegments(ctx, req) + suite.NoError(err) + suite.Equal(commonpb.ErrorCode_Success, status.GetErrorCode()) + } } func (suite *ServiceSuite) TestLoadSegments_VarChar() { @@ -547,24 +550,28 @@ func (suite *ServiceSuite) TestLoadSegments_VarChar() { } suite.node.manager.Collection = segments.NewCollectionManager() suite.node.manager.Collection.PutOrRef(suite.collectionID, schema, nil, loadMeta) - req := &querypb.LoadSegmentsRequest{ - Base: &commonpb.MsgBase{ - MsgID: rand.Int63(), - TargetID: suite.node.session.ServerID, - }, - CollectionID: suite.collectionID, - DstNodeID: suite.node.session.ServerID, - Infos: suite.genSegmentLoadInfos(schema), - Schema: schema, - DeltaPositions: []*msgpb.MsgPosition{{Timestamp: 20000}}, - NeedTransfer: true, - LoadMeta: loadMeta, - } - // LoadSegment - status, err := suite.node.LoadSegments(ctx, req) - suite.NoError(err) - suite.Equal(commonpb.ErrorCode_Success, status.GetErrorCode()) + infos := suite.genSegmentLoadInfos(schema) + for _, info := range infos { + req := &querypb.LoadSegmentsRequest{ + Base: &commonpb.MsgBase{ + MsgID: rand.Int63(), + TargetID: suite.node.session.ServerID, + }, + CollectionID: suite.collectionID, + DstNodeID: suite.node.session.ServerID, + Infos: []*querypb.SegmentLoadInfo{info}, + Schema: schema, + DeltaPositions: []*msgpb.MsgPosition{{Timestamp: 20000}}, + NeedTransfer: true, + LoadMeta: loadMeta, + } + + // LoadSegment + status, err := suite.node.LoadSegments(ctx, req) + suite.NoError(err) + suite.Equal(commonpb.ErrorCode_Success, status.GetErrorCode()) + } } func (suite *ServiceSuite) TestLoadDeltaInt64() {