diff --git a/internal/querynode/load_segment_task.go b/internal/querynode/load_segment_task.go index 85ed43851a..e68c3fb8e4 100644 --- a/internal/querynode/load_segment_task.go +++ b/internal/querynode/load_segment_task.go @@ -78,51 +78,53 @@ func (l *loadSegmentsTask) Execute(ctx context.Context) error { segmentIDs := lo.Map(l.req.Infos, func(info *queryPb.SegmentLoadInfo, idx int) UniqueID { return info.SegmentID }) l.node.metaReplica.addSegmentsLoadingList(segmentIDs) defer l.node.metaReplica.removeSegmentsLoadingList(segmentIDs) - err := l.node.loader.LoadSegment(l.ctx, l.req, segmentTypeSealed) - if err != nil { + loadDoneSegmentIDs, loadErr := l.node.loader.LoadSegment(l.ctx, l.req, segmentTypeSealed) + if len(loadDoneSegmentIDs) > 0 { + vchanName := make([]string, 0) + for _, deltaPosition := range l.req.DeltaPositions { + vchanName = append(vchanName, deltaPosition.ChannelName) + } + + // TODO delta channel need to released 1. if other watchDeltaChannel fail 2. when segment release + err := l.watchDeltaChannel(vchanName) + if err != nil { + // roll back + for _, segment := range l.req.Infos { + l.node.metaReplica.removeSegment(segment.SegmentID, segmentTypeSealed) + } + log.Warn("failed to watch Delta channel while load segment", zap.Int64("collectionID", l.req.CollectionID), + zap.Int64("replicaID", l.req.ReplicaID), zap.Error(err)) + return err + } + + runningGroup, groupCtx := errgroup.WithContext(l.ctx) + for _, deltaPosition := range l.req.DeltaPositions { + pos := deltaPosition + runningGroup.Go(func() error { + // reload data from dml channel + return l.node.loader.FromDmlCPLoadDelete(groupCtx, l.req.CollectionID, pos, + lo.FilterMap(l.req.Infos, func(info *queryPb.SegmentLoadInfo, _ int) (int64, bool) { + return info.GetSegmentID(), funcutil.SliceContain(loadDoneSegmentIDs, info.SegmentID) && info.GetInsertChannel() == pos.GetChannelName() + })) + }) + } + err = runningGroup.Wait() + if err != nil { + for _, segment := range l.req.Infos { + l.node.metaReplica.removeSegment(segment.SegmentID, segmentTypeSealed) + } + for _, vchannel := range vchanName { + l.node.dataSyncService.removeEmptyFlowGraphByChannel(l.req.CollectionID, vchannel) + } + log.Warn("failed to load delete data while load segment", zap.Int64("collectionID", l.req.CollectionID), + zap.Int64("replicaID", l.req.ReplicaID), zap.Error(err)) + return err + } + } + if loadErr != nil { log.Warn("failed to load segment", zap.Int64("collectionID", l.req.CollectionID), - zap.Int64("replicaID", l.req.ReplicaID), zap.Error(err)) - return err - } - vchanName := make([]string, 0) - for _, deltaPosition := range l.req.DeltaPositions { - vchanName = append(vchanName, deltaPosition.ChannelName) - } - - // TODO delta channel need to released 1. if other watchDeltaChannel fail 2. when segment release - err = l.watchDeltaChannel(vchanName) - if err != nil { - // roll back - for _, segment := range l.req.Infos { - l.node.metaReplica.removeSegment(segment.SegmentID, segmentTypeSealed) - } - log.Warn("failed to watch Delta channel while load segment", zap.Int64("collectionID", l.req.CollectionID), - zap.Int64("replicaID", l.req.ReplicaID), zap.Error(err)) - return err - } - - runningGroup, groupCtx := errgroup.WithContext(l.ctx) - for _, deltaPosition := range l.req.DeltaPositions { - pos := deltaPosition - runningGroup.Go(func() error { - // reload data from dml channel - return l.node.loader.FromDmlCPLoadDelete(groupCtx, l.req.CollectionID, pos, - lo.FilterMap(l.req.Infos, func(info *queryPb.SegmentLoadInfo, _ int) (int64, bool) { - return info.GetSegmentID(), info.GetInsertChannel() == pos.GetChannelName() - })) - }) - } - err = runningGroup.Wait() - if err != nil { - for _, segment := range l.req.Infos { - l.node.metaReplica.removeSegment(segment.SegmentID, segmentTypeSealed) - } - for _, vchannel := range vchanName { - l.node.dataSyncService.removeEmptyFlowGraphByChannel(l.req.CollectionID, vchannel) - } - log.Warn("failed to load delete data while load segment", zap.Int64("collectionID", l.req.CollectionID), - zap.Int64("replicaID", l.req.ReplicaID), zap.Error(err)) - return err + zap.Int64("replicaID", l.req.ReplicaID), zap.Error(loadErr)) + return loadErr } log.Info("LoadSegmentTask Execute done", zap.Int64("collectionID", l.req.CollectionID), diff --git a/internal/querynode/load_segment_task_test.go b/internal/querynode/load_segment_task_test.go index 9a01bf783e..6f3fef618d 100644 --- a/internal/querynode/load_segment_task_test.go +++ b/internal/querynode/load_segment_task_test.go @@ -42,6 +42,12 @@ func TestTask_loadSegmentsTask(t *testing.T) { defer cancel() schema := genTestCollectionSchema() + node, err := genSimpleQueryNode(ctx) + assert.NoError(t, err) + testVChannel := "by-dev-rootcoord-dml_1_2021v1" + fieldBinlog, statsLog, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, defaultMsgLength, schema) + assert.NoError(t, err) + genLoadEmptySegmentsRequest := func() *querypb.LoadSegmentsRequest { req := &querypb.LoadSegmentsRequest{ Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, 0), @@ -79,14 +85,8 @@ func TestTask_loadSegmentsTask(t *testing.T) { }) t.Run("test execute grpc", func(t *testing.T) { - node, err := genSimpleQueryNode(ctx) - assert.NoError(t, err) - node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed) - fieldBinlog, statsLog, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, defaultMsgLength, schema) - assert.NoError(t, err) - req := &querypb.LoadSegmentsRequest{ Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID), Schema: schema, @@ -113,14 +113,8 @@ func TestTask_loadSegmentsTask(t *testing.T) { }) t.Run("test repeated load", func(t *testing.T) { - node, err := genSimpleQueryNode(ctx) - assert.NoError(t, err) - node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed) - fieldBinlog, statsLog, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, defaultMsgLength, schema) - assert.NoError(t, err) - req := &querypb.LoadSegmentsRequest{ Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID), Schema: schema, @@ -156,121 +150,8 @@ func TestTask_loadSegmentsTask(t *testing.T) { assert.Equal(t, 1, num) }) - t.Run("test FromDmlCPLoadDelete", func(t *testing.T) { - node, err := genSimpleQueryNode(ctx) - assert.NoError(t, err) - - vDmChannel := "by-dev-rootcoord-dml_1_2021v1" - pDmChannel := funcutil.ToPhysicalChannel(vDmChannel) - stream, err := node.factory.NewMsgStream(node.queryNodeLoopCtx) - assert.Nil(t, err) - stream.AsProducer([]string{pDmChannel}) - timeTickMsg := &msgstream.TimeTickMsg{ - BaseMsg: msgstream.BaseMsg{ - HashValues: []uint32{1}, - }, - TimeTickMsg: internalpb.TimeTickMsg{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_TimeTick, - Timestamp: 100, - }, - }, - } - - deleteMsg := &msgstream.DeleteMsg{ - BaseMsg: msgstream.BaseMsg{ - HashValues: []uint32{1, 1, 1}, - }, - DeleteRequest: internalpb.DeleteRequest{ - Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_Delete, - Timestamp: 110, - }, - CollectionID: defaultCollectionID, - PartitionID: defaultPartitionID, - PrimaryKeys: &schemapb.IDs{ - IdField: &schemapb.IDs_IntId{ - IntId: &schemapb.LongArray{ - Data: []int64{1, 2, 3}, - }, - }, - }, - Timestamps: []Timestamp{110, 110, 110}, - NumRows: 3, - }, - } - - pos1, err := stream.ProduceMark(&msgstream.MsgPack{Msgs: []msgstream.TsMsg{timeTickMsg}}) - assert.NoError(t, err) - msgIDs, ok := pos1[pDmChannel] - assert.True(t, ok) - assert.Equal(t, 1, len(msgIDs)) - err = stream.Produce(&msgstream.MsgPack{Msgs: []msgstream.TsMsg{deleteMsg}}) - assert.NoError(t, err) - - // to stop reader from cp - go func() { - for { - select { - case <-ctx.Done(): - break - default: - timeTickMsg.Base.Timestamp += 100 - stream.Produce(&msgstream.MsgPack{Msgs: []msgstream.TsMsg{timeTickMsg}}) - time.Sleep(200 * time.Millisecond) - } - } - }() - - segmentID := defaultSegmentID + 1 - fieldBinlog, statsLog, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, segmentID, defaultMsgLength, schema) - assert.NoError(t, err) - - req := &querypb.LoadSegmentsRequest{ - Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID), - Schema: schema, - Infos: []*querypb.SegmentLoadInfo{ - { - SegmentID: segmentID, - PartitionID: defaultPartitionID, - CollectionID: defaultCollectionID, - BinlogPaths: fieldBinlog, - NumOfRows: defaultMsgLength, - Statslogs: statsLog, - InsertChannel: vDmChannel, - }, - }, - DeltaPositions: []*internalpb.MsgPosition{ - { - ChannelName: vDmChannel, - MsgID: msgIDs[0].Serialize(), - Timestamp: 100, - }, - }, - } - - task := loadSegmentsTask{ - baseTask: baseTask{ - ctx: ctx, - }, - req: req, - node: node, - } - err = task.PreExecute(ctx) - assert.NoError(t, err) - err = task.Execute(ctx) - assert.NoError(t, err) - segment, err := node.metaReplica.getSegmentByID(segmentID, segmentTypeSealed) - assert.NoError(t, err) - - // has reload 3 delete log from dm channel, so next delete offset should be 3 - offset := segment.segmentPreDelete(1) - assert.Equal(t, int64(3), offset) - }) - t.Run("test OOM", func(t *testing.T) { - node, err := genSimpleQueryNode(ctx) - assert.NoError(t, err) + node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed) totalRAM := int64(hardware.GetMemoryCount()) @@ -311,9 +192,8 @@ func TestTask_loadSegmentsTask(t *testing.T) { assert.Contains(t, err.Error(), "OOM") }) + factory := node.loader.factory t.Run("test FromDmlCPLoadDelete failed", func(t *testing.T) { - node, err := genSimpleQueryNode(ctx) - assert.NoError(t, err) node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed) msgStream := &LoadDeleteMsgStream{} @@ -334,7 +214,7 @@ func TestTask_loadSegmentsTask(t *testing.T) { Infos: []*querypb.SegmentLoadInfo{segmentLoadInfo}, DeltaPositions: []*internalpb.MsgPosition{ { - ChannelName: "by-dev-rootcoord-dml-test_1_2021v2", + ChannelName: testVChannel, MsgID: rmq.SerializeRmqID(0), }, }, @@ -357,6 +237,177 @@ func TestTask_loadSegmentsTask(t *testing.T) { fgNum := node.dataSyncService.getFlowGraphNum() assert.Equal(t, 0, fgNum) }) + + node.loader.factory = factory + pDmChannel := funcutil.ToPhysicalChannel(testVChannel) + stream, err := node.factory.NewMsgStream(node.queryNodeLoopCtx) + assert.Nil(t, err) + stream.AsProducer([]string{pDmChannel}) + timeTickMsg := &msgstream.TimeTickMsg{ + BaseMsg: msgstream.BaseMsg{ + HashValues: []uint32{1}, + }, + TimeTickMsg: internalpb.TimeTickMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_TimeTick, + Timestamp: 100, + }, + }, + } + + deleteMsg := &msgstream.DeleteMsg{ + BaseMsg: msgstream.BaseMsg{ + HashValues: []uint32{1, 1, 1}, + }, + DeleteRequest: internalpb.DeleteRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Delete, + Timestamp: 110, + }, + CollectionID: defaultCollectionID, + PartitionID: defaultPartitionID, + PrimaryKeys: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1, 2, 3}, + }, + }, + }, + Timestamps: []Timestamp{110, 110, 110}, + NumRows: 3, + }, + } + + pos1, err := stream.ProduceMark(&msgstream.MsgPack{Msgs: []msgstream.TsMsg{timeTickMsg}}) + assert.NoError(t, err) + msgIDs, ok := pos1[pDmChannel] + assert.True(t, ok) + assert.Equal(t, 1, len(msgIDs)) + err = stream.Produce(&msgstream.MsgPack{Msgs: []msgstream.TsMsg{deleteMsg}}) + assert.NoError(t, err) + + // to stop reader from cp + go func() { + for { + select { + case <-ctx.Done(): + break + default: + timeTickMsg.Base.Timestamp += 100 + stream.Produce(&msgstream.MsgPack{Msgs: []msgstream.TsMsg{timeTickMsg}}) + time.Sleep(200 * time.Millisecond) + } + } + }() + + t.Run("test FromDmlCPLoadDelete", func(t *testing.T) { + node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed) + + req := &querypb.LoadSegmentsRequest{ + Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID), + Schema: schema, + Infos: []*querypb.SegmentLoadInfo{ + { + SegmentID: defaultSegmentID, + PartitionID: defaultPartitionID, + CollectionID: defaultCollectionID, + BinlogPaths: fieldBinlog, + NumOfRows: defaultMsgLength, + Statslogs: statsLog, + InsertChannel: testVChannel, + }, + }, + DeltaPositions: []*internalpb.MsgPosition{ + { + ChannelName: testVChannel, + MsgID: msgIDs[0].Serialize(), + Timestamp: 100, + }, + }, + } + + task := loadSegmentsTask{ + baseTask: baseTask{ + ctx: ctx, + }, + req: req, + node: node, + } + err = task.PreExecute(ctx) + assert.NoError(t, err) + err = task.Execute(ctx) + assert.NoError(t, err) + segment, err := node.metaReplica.getSegmentByID(defaultSegmentID, segmentTypeSealed) + assert.NoError(t, err) + + // has reload 3 delete log from dm channel, so next delete offset should be 3 + offset := segment.segmentPreDelete(1) + assert.Equal(t, int64(3), offset) + }) + + t.Run("test load with partial success", func(t *testing.T) { + deltaChannel, err := funcutil.ConvertChannelName(testVChannel, Params.CommonCfg.RootCoordDml, Params.CommonCfg.RootCoordDelta) + assert.NoError(t, err) + + node.metaReplica.removeSegment(defaultSegmentID, segmentTypeSealed) + node.dataSyncService.removeFlowGraphsByDMLChannels([]Channel{testVChannel}) + node.dataSyncService.removeFlowGraphsByDeltaChannels([]Channel{deltaChannel}) + + fakeFieldBinlog, fakeStatsBinlog, err := getFakeBinLog(ctx, defaultCollectionID, defaultPartitionID, defaultSegmentID, defaultMsgLength, schema) + assert.NoError(t, err) + + segmentID1 := defaultSegmentID + segmentID2 := defaultSegmentID + 1 + req := &querypb.LoadSegmentsRequest{ + Base: genCommonMsgBase(commonpb.MsgType_LoadSegments, node.session.ServerID), + Schema: schema, + Infos: []*querypb.SegmentLoadInfo{ + { + SegmentID: segmentID1, + PartitionID: defaultPartitionID, + CollectionID: defaultCollectionID, + BinlogPaths: fieldBinlog, + NumOfRows: defaultMsgLength, + Statslogs: statsLog, + InsertChannel: testVChannel, + }, + { + SegmentID: segmentID2, + PartitionID: defaultPartitionID, + CollectionID: defaultCollectionID, + BinlogPaths: fakeFieldBinlog, + NumOfRows: defaultMsgLength, + Statslogs: fakeStatsBinlog, + InsertChannel: testVChannel, + }, + }, + DeltaPositions: []*internalpb.MsgPosition{ + { + ChannelName: testVChannel, + MsgID: msgIDs[0].Serialize(), + Timestamp: 100, + }, + }, + } + + task := loadSegmentsTask{ + baseTask: baseTask{ + ctx: ctx, + }, + req: req, + node: node, + } + err = task.PreExecute(ctx) + assert.NoError(t, err) + err = task.Execute(ctx) + assert.Error(t, err) + exist, err := node.metaReplica.hasSegment(segmentID1, segmentTypeSealed) + assert.NoError(t, err) + assert.True(t, exist) + exist, err = node.metaReplica.hasSegment(segmentID2, segmentTypeSealed) + assert.NoError(t, err) + assert.False(t, exist) + }) } func TestTask_loadSegmentsTaskLoadDelta(t *testing.T) { diff --git a/internal/querynode/mock_test.go b/internal/querynode/mock_test.go index 69f69f657b..7bebb7c243 100644 --- a/internal/querynode/mock_test.go +++ b/internal/querynode/mock_test.go @@ -338,7 +338,7 @@ func loadIndexForSegment(ctx context.Context, node *QueryNode, segmentID UniqueI }, } - err = loader.LoadSegment(ctx, req, segmentTypeSealed) + _, err = loader.LoadSegment(ctx, req, segmentTypeSealed) if err != nil { return err } @@ -964,6 +964,54 @@ func genSimpleInsertMsg(schema *schemapb.CollectionSchema, numRows int) (*msgstr }, nil } +func getFakeBinLog(ctx context.Context, + collectionID UniqueID, + partitionID UniqueID, + segmentID UniqueID, + msgLength int, + schema *schemapb.CollectionSchema) ([]*datapb.FieldBinlog, []*datapb.FieldBinlog, error) { + binLogs, statsLogs, err := genStorageBlob(collectionID, + partitionID, + segmentID, + msgLength, + schema) + if err != nil { + return nil, nil, err + } + + // gen fake insert binlog path, don't write data to minio + fieldBinlog := make([]*datapb.FieldBinlog, 0) + for _, blob := range binLogs { + fieldID, err := strconv.ParseInt(blob.GetKey(), 10, 64) + if err != nil { + return nil, nil, err + } + + k := JoinIDPath(collectionID, partitionID, segmentID, fieldID) + fieldBinlog = append(fieldBinlog, &datapb.FieldBinlog{ + FieldID: fieldID, + Binlogs: []*datapb.Binlog{{LogPath: path.Join("insert-log", k, "notExistKey")}}, + }) + } + + // gen fake stats binlog path, don't write data to minio + statsBinlog := make([]*datapb.FieldBinlog, 0) + for _, blob := range statsLogs { + fieldID, err := strconv.ParseInt(blob.GetKey(), 10, 64) + if err != nil { + return nil, nil, err + } + + k := JoinIDPath(collectionID, partitionID, segmentID, fieldID) + statsBinlog = append(statsBinlog, &datapb.FieldBinlog{ + FieldID: fieldID, + Binlogs: []*datapb.Binlog{{LogPath: path.Join("delta-log", k, "notExistKey")}}, + }) + } + + return fieldBinlog, statsBinlog, err +} + func saveBinLog(ctx context.Context, collectionID UniqueID, partitionID UniqueID, diff --git a/internal/querynode/segment_loader.go b/internal/querynode/segment_loader.go index 79f5a07331..3da8c38df5 100644 --- a/internal/querynode/segment_loader.go +++ b/internal/querynode/segment_loader.go @@ -27,6 +27,7 @@ import ( "strconv" "time" + "go.uber.org/multierr" "go.uber.org/zap" "golang.org/x/sync/errgroup" @@ -86,9 +87,9 @@ func (loader *segmentLoader) getFieldType(segment *Segment, fieldID FieldID) (sc return coll.getFieldType(fieldID) } -func (loader *segmentLoader) LoadSegment(ctx context.Context, req *querypb.LoadSegmentsRequest, segmentType segmentType) error { +func (loader *segmentLoader) LoadSegment(ctx context.Context, req *querypb.LoadSegmentsRequest, segmentType segmentType) ([]UniqueID, error) { if req.Base == nil { - return fmt.Errorf("nil base message when load segment, collectionID = %d", req.CollectionID) + return nil, fmt.Errorf("nil base message when load segment, collectionID = %d", req.CollectionID) } log := log.With(zap.Int64("collectionID", req.CollectionID), zap.String("segmentType", segmentType.String())) @@ -97,7 +98,7 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context, req *querypb.LoadS if segmentNum == 0 { log.Warn("find no valid segment target, skip load segment", zap.Any("request", req)) - return nil + return nil, nil } log.Info("segmentLoader start loading...", zap.Any("segmentNum", segmentNum)) @@ -125,13 +126,16 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context, req *querypb.LoadS log.Error("load failed, OOM if loaded", zap.Int64("loadSegmentRequest msgID", req.Base.MsgID), zap.Error(err)) - return err + return nil, err } - newSegments := make(map[UniqueID]*Segment, len(req.Infos)) - segmentGC := func() { - for _, s := range newSegments { - deleteSegment(s) + newSegments := make(map[UniqueID]*Segment, segmentNum) + loadDoneSegmentIDSet := typeutil.NewConcurrentSet[int64]() + segmentGC := func(force bool) { + for id, s := range newSegments { + if force || !loadDoneSegmentIDSet.Contain(id) { + deleteSegment(s) + } } debug.FreeOSMemory() } @@ -144,8 +148,8 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context, req *querypb.LoadS collection, err := loader.metaReplica.getCollectionByID(collectionID) if err != nil { - segmentGC() - return err + segmentGC(true) + return nil, err } segment, err := newSegment(collection, segmentID, partitionID, collectionID, vChannelID, segmentType, req.GetVersion(), loader.cgoPool) @@ -154,8 +158,8 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context, req *querypb.LoadS zap.Int64("partitionID", partitionID), zap.Int64("segmentID", segmentID), zap.Error(err)) - segmentGC() - return err + segmentGC(true) + return nil, err } newSegments[segmentID] = segment @@ -177,6 +181,7 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context, req *querypb.LoadS return err } + loadDoneSegmentIDSet.Insert(segmentID) metrics.QueryNodeLoadSegmentLatency.WithLabelValues(fmt.Sprint(Params.QueryNodeCfg.GetNodeID())).Observe(float64(tr.ElapseSpan().Milliseconds())) return nil @@ -187,29 +192,36 @@ func (loader *segmentLoader) LoadSegment(ctx context.Context, req *querypb.LoadS log.Info("start to load segments in parallel", zap.Int("segmentNum", segmentNum), zap.Int("concurrencyLevel", concurrencyLevel)) - err = funcutil.ProcessFuncParallel(segmentNum, + loadErr := funcutil.ProcessFuncParallel(segmentNum, concurrencyLevel, loadFileFunc, "loadSegmentFunc") - if err != nil { - segmentGC() - return err - } - // set segment to meta replica - for _, s := range newSegments { - err = loader.metaReplica.setSegment(s) + // set segment which has been loaded done to meta replica + failedSetMetaSegmentIDs := make([]UniqueID, 0) + for _, id := range loadDoneSegmentIDSet.Collect() { + segment := newSegments[id] + err = loader.metaReplica.setSegment(segment) if err != nil { log.Error("load segment failed, set segment to meta failed", - zap.Int64("collectionID", s.collectionID), - zap.Int64("partitionID", s.partitionID), - zap.Int64("segmentID", s.segmentID), + zap.Int64("collectionID", segment.collectionID), + zap.Int64("partitionID", segment.partitionID), + zap.Int64("segmentID", segment.segmentID), zap.Int64("loadSegmentRequest msgID", req.Base.MsgID), zap.Error(err)) - segmentGC() - return err + failedSetMetaSegmentIDs = append(failedSetMetaSegmentIDs, id) + loadDoneSegmentIDSet.Remove(id) } } + if len(failedSetMetaSegmentIDs) > 0 { + err = fmt.Errorf("load segment failed, set segment to meta failed, segmentIDs: %v", failedSetMetaSegmentIDs) + } - return nil + err = multierr.Combine(loadErr, err) + if err != nil { + segmentGC(false) + return loadDoneSegmentIDSet.Collect(), err + } + + return loadDoneSegmentIDSet.Collect(), nil } func (loader *segmentLoader) loadFiles(ctx context.Context, segment *Segment, diff --git a/internal/querynode/segment_loader_test.go b/internal/querynode/segment_loader_test.go index d964ddbdb0..b638c3a51e 100644 --- a/internal/querynode/segment_loader_test.go +++ b/internal/querynode/segment_loader_test.go @@ -61,7 +61,7 @@ func TestSegmentLoader_loadSegment(t *testing.T) { req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_WatchQueryChannels, + MsgType: commonpb.MsgType_LoadSegments, MsgID: rand.Int63(), }, DstNodeID: 0, @@ -77,10 +77,56 @@ func TestSegmentLoader_loadSegment(t *testing.T) { }, } - err = loader.LoadSegment(ctx, req, segmentTypeSealed) + _, err = loader.LoadSegment(ctx, req, segmentTypeSealed) assert.NoError(t, err) }) + t.Run("test load segment error due to partial success", func(t *testing.T) { + node, err := genSimpleQueryNode(ctx) + assert.NoError(t, err) + + loader := node.loader + assert.NotNil(t, loader) + + existPatitionID := defaultPartitionID + notExistPartitionID := defaultPartitionID + 1 + segmentID1 := defaultSegmentID + 1 + segmentID2 := defaultSegmentID + 2 + req := &querypb.LoadSegmentsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadSegments, + MsgID: rand.Int63(), + }, + DstNodeID: 0, + Schema: schema, + Infos: []*querypb.SegmentLoadInfo{ + { + SegmentID: segmentID1, + PartitionID: existPatitionID, + CollectionID: defaultCollectionID, + BinlogPaths: fieldBinlog, + }, + { + SegmentID: segmentID2, + PartitionID: notExistPartitionID, + CollectionID: defaultCollectionID, + BinlogPaths: fieldBinlog, + }, + }, + } + + loadDoneSegmentIDs, err := loader.LoadSegment(ctx, req, segmentTypeSealed) + assert.Error(t, err) + assert.Equal(t, 1, len(loadDoneSegmentIDs)) + assert.Equal(t, segmentID1, loadDoneSegmentIDs[0]) + exist, err := node.metaReplica.hasSegment(segmentID1, segmentTypeSealed) + assert.NoError(t, err) + assert.True(t, exist) + exist, err = node.metaReplica.hasSegment(segmentID2, segmentTypeSealed) + assert.NoError(t, err) + assert.False(t, exist) + }) + t.Run("test set segment error due to without partition", func(t *testing.T) { node, err := genSimpleQueryNode(ctx) assert.NoError(t, err) @@ -91,16 +137,17 @@ func TestSegmentLoader_loadSegment(t *testing.T) { loader := node.loader assert.NotNil(t, loader) + segmentID := defaultSegmentID + 3 req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_WatchQueryChannels, + MsgType: commonpb.MsgType_LoadSegments, MsgID: rand.Int63(), }, DstNodeID: 0, Schema: schema, Infos: []*querypb.SegmentLoadInfo{ { - SegmentID: defaultSegmentID, + SegmentID: segmentID, PartitionID: defaultPartitionID, CollectionID: defaultCollectionID, BinlogPaths: fieldBinlog, @@ -108,7 +155,7 @@ func TestSegmentLoader_loadSegment(t *testing.T) { }, } - err = loader.LoadSegment(ctx, req, segmentTypeSealed) + _, err = loader.LoadSegment(ctx, req, segmentTypeSealed) assert.Error(t, err) }) @@ -121,7 +168,7 @@ func TestSegmentLoader_loadSegment(t *testing.T) { req := &querypb.LoadSegmentsRequest{} - err = loader.LoadSegment(ctx, req, segmentTypeSealed) + _, err = loader.LoadSegment(ctx, req, segmentTypeSealed) assert.Error(t, err) }) } @@ -234,7 +281,7 @@ func TestSegmentLoader_invalid(t *testing.T) { req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_WatchQueryChannels, + MsgType: commonpb.MsgType_LoadSegments, MsgID: rand.Int63(), }, DstNodeID: 0, @@ -247,7 +294,7 @@ func TestSegmentLoader_invalid(t *testing.T) { }, } - err = loader.LoadSegment(ctx, req, segmentTypeSealed) + _, err = loader.LoadSegment(ctx, req, segmentTypeSealed) assert.Error(t, err) }) @@ -272,7 +319,7 @@ func TestSegmentLoader_invalid(t *testing.T) { req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_WatchQueryChannels, + MsgType: commonpb.MsgType_LoadSegments, MsgID: rand.Int63(), }, DstNodeID: 0, @@ -285,7 +332,7 @@ func TestSegmentLoader_invalid(t *testing.T) { }, }, } - err = loader.LoadSegment(ctx, req, segmentTypeSealed) + _, err = loader.LoadSegment(ctx, req, segmentTypeSealed) assert.Error(t, err) }) @@ -297,7 +344,7 @@ func TestSegmentLoader_invalid(t *testing.T) { req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_WatchQueryChannels, + MsgType: commonpb.MsgType_LoadSegments, MsgID: rand.Int63(), }, DstNodeID: 0, @@ -310,7 +357,7 @@ func TestSegmentLoader_invalid(t *testing.T) { }, } - err = loader.LoadSegment(ctx, req, commonpb.SegmentState_Dropped) + _, err = loader.LoadSegment(ctx, req, commonpb.SegmentState_Dropped) assert.Error(t, err) }) @@ -500,7 +547,7 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) { segmentID1 := UniqueID(100) req1 := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_WatchQueryChannels, + MsgType: commonpb.MsgType_LoadSegments, MsgID: rand.Int63(), }, DstNodeID: 0, @@ -516,7 +563,7 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) { }, } - err = loader.LoadSegment(ctx, req1, segmentTypeSealed) + _, err = loader.LoadSegment(ctx, req1, segmentTypeSealed) assert.NoError(t, err) segment1, err := loader.metaReplica.getSegmentByID(segmentID1, segmentTypeSealed) @@ -526,7 +573,7 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) { segmentID2 := UniqueID(101) req2 := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_WatchQueryChannels, + MsgType: commonpb.MsgType_LoadSegments, MsgID: rand.Int63(), }, DstNodeID: 0, @@ -542,7 +589,7 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) { }, } - err = loader.LoadSegment(ctx, req2, segmentTypeSealed) + _, err = loader.LoadSegment(ctx, req2, segmentTypeSealed) assert.NoError(t, err) segment2, err := loader.metaReplica.getSegmentByID(segmentID2, segmentTypeSealed) @@ -561,7 +608,7 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) { segmentID1 := UniqueID(100) req1 := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_WatchQueryChannels, + MsgType: commonpb.MsgType_LoadSegments, MsgID: rand.Int63(), }, DstNodeID: 0, @@ -576,7 +623,7 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) { }, } - err = loader.LoadSegment(ctx, req1, segmentTypeGrowing) + _, err = loader.LoadSegment(ctx, req1, segmentTypeGrowing) assert.NoError(t, err) segment1, err := loader.metaReplica.getSegmentByID(segmentID1, segmentTypeGrowing) @@ -586,7 +633,7 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) { segmentID2 := UniqueID(101) req2 := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_WatchQueryChannels, + MsgType: commonpb.MsgType_LoadSegments, MsgID: rand.Int63(), }, DstNodeID: 0, @@ -602,7 +649,7 @@ func TestSegmentLoader_testLoadGrowingAndSealed(t *testing.T) { }, } - err = loader.LoadSegment(ctx, req2, segmentTypeGrowing) + _, err = loader.LoadSegment(ctx, req2, segmentTypeGrowing) assert.NoError(t, err) segment2, err := loader.metaReplica.getSegmentByID(segmentID2, segmentTypeGrowing) @@ -645,7 +692,7 @@ func TestSegmentLoader_testLoadSealedSegmentWithIndex(t *testing.T) { req := &querypb.LoadSegmentsRequest{ Base: &commonpb.MsgBase{ - MsgType: commonpb.MsgType_WatchQueryChannels, + MsgType: commonpb.MsgType_LoadSegments, MsgID: rand.Int63(), }, DstNodeID: 0, @@ -662,7 +709,7 @@ func TestSegmentLoader_testLoadSealedSegmentWithIndex(t *testing.T) { }, } - err = loader.LoadSegment(ctx, req, segmentTypeSealed) + _, err = loader.LoadSegment(ctx, req, segmentTypeSealed) assert.NoError(t, err) segment, err := node.metaReplica.getSegmentByID(segmentID, segmentTypeSealed) diff --git a/internal/querynode/watch_dm_channels_task.go b/internal/querynode/watch_dm_channels_task.go index 0843b6ba2e..a6e189a75a 100644 --- a/internal/querynode/watch_dm_channels_task.go +++ b/internal/querynode/watch_dm_channels_task.go @@ -204,7 +204,7 @@ func (w *watchDmChannelsTask) LoadGrowingSegments(ctx context.Context, collectio zap.Int64("collectionID", collectionID), zap.Int64s("unFlushedSegmentIDs", unFlushedSegmentIDs), ) - err := w.node.loader.LoadSegment(w.ctx, req, segmentTypeGrowing) + _, err := w.node.loader.LoadSegment(w.ctx, req, segmentTypeGrowing) if err != nil { log.Warn("failed to load segment", zap.Int64("collection", collectionID), zap.Error(err)) return nil, err