From c8f9f22c4a58cdc27ee80f07532182efaa2bf689 Mon Sep 17 00:00:00 2001
From: congqixia
Date: Fri, 8 Sep 2023 16:41:16 +0800
Subject: [PATCH] Fix segment loader return false success (#26926)

`waitSegmentLoadDone` did not check whether the waitCh result was a
success or a failure after the load returned without error, so the
delegator would assume that all segments were loaded.

This PR replaces waitCh with a `loadResult`, which pairs a `sync.Cond`
for notification with an `atomic.Int32` representing the load status.

Signed-off-by: Congqi Xia
---
 .../querynodev2/segments/segment_loader.go    |  70 ++++++++---
 .../segments/segment_loader_test.go           | 119 ++++++++++++++++++
 2 files changed, 173 insertions(+), 16 deletions(-)

diff --git a/internal/querynodev2/segments/segment_loader.go b/internal/querynodev2/segments/segment_loader.go
index 7447be73e5..d226807768 100644
--- a/internal/querynodev2/segments/segment_loader.go
+++ b/internal/querynodev2/segments/segment_loader.go
@@ -28,6 +28,7 @@ import (
 
 	"github.com/cockroachdb/errors"
 	"github.com/samber/lo"
+	"go.uber.org/atomic"
 	"go.uber.org/zap"
 	"golang.org/x/sync/errgroup"
 
@@ -113,12 +114,37 @@ func NewLoader(
 	loader := &segmentLoader{
 		manager:         manager,
 		cm:              cm,
-		loadingSegments: typeutil.NewConcurrentMap[int64, chan struct{}](),
+		loadingSegments: typeutil.NewConcurrentMap[int64, *loadResult](),
 	}
 	return loader
 }
 
+type loadStatus = int32
+
+const (
+	loading loadStatus = iota + 1
+	success
+	failure
+)
+
+type loadResult struct {
+	status *atomic.Int32
+	cond   *sync.Cond
+}
+
+func newLoadResult() *loadResult {
+	return &loadResult{
+		status: atomic.NewInt32(loading),
+		cond:   sync.NewCond(&sync.Mutex{}),
+	}
+}
+
+func (r *loadResult) SetResult(status loadStatus) {
+	r.status.CompareAndSwap(loading, status)
+	r.cond.Broadcast()
+}
+
 // segmentLoader is only responsible for loading the field data from binlog
 type segmentLoader struct {
 	manager *Manager
@@ -126,7 +152,7 @@ type segmentLoader struct {
 
 	mut sync.Mutex
 	// The channel will be closed as the segment loaded
-	loadingSegments *typeutil.ConcurrentMap[int64, chan struct{}]
+	loadingSegments *typeutil.ConcurrentMap[int64, *loadResult]
 
 	committedResource LoadResource
 }
@@ -259,7 +285,7 @@ func (loader *segmentLoader) prepare(segmentType SegmentType, version int64, seg
 		if len(loader.manager.Segment.GetBy(WithType(segmentType), WithID(segment.GetSegmentID()))) == 0 &&
 			!loader.loadingSegments.Contain(segment.GetSegmentID()) {
 			infos = append(infos, segment)
-			loader.loadingSegments.Insert(segment.GetSegmentID(), make(chan struct{}))
+			loader.loadingSegments.Insert(segment.GetSegmentID(), newLoadResult())
 		} else {
 			// try to update segment version before skip load operation
 			loader.manager.Segment.UpdateSegmentBy(IncreaseVersion(version),
@@ -278,13 +304,9 @@ func (loader *segmentLoader) unregister(segments ...*querypb.SegmentLoadInfo) {
 	loader.mut.Lock()
 	defer loader.mut.Unlock()
 	for i := range segments {
-		waitCh, ok := loader.loadingSegments.GetAndRemove(segments[i].GetSegmentID())
+		result, ok := loader.loadingSegments.GetAndRemove(segments[i].GetSegmentID())
 		if ok {
-			select {
-			case <-waitCh:
-			default: // close wait channel for failed task
-				close(waitCh)
-			}
+			result.SetResult(failure)
 		}
 	}
 }
@@ -292,9 +314,9 @@ func (loader *segmentLoader) notifyLoadFinish(segments ...*querypb.SegmentLoadInfo) {
 	for _, loadInfo := range segments {
-		waitCh, ok := loader.loadingSegments.Get(loadInfo.GetSegmentID())
+		result, ok := loader.loadingSegments.Get(loadInfo.GetSegmentID())
 		if ok {
-			close(waitCh)
+			result.SetResult(success)
 		}
 	}
 }
@@ -395,18 +417,34 @@ func (loader *segmentLoader) waitSegmentLoadDone(ctx context.Context, segmentTyp
 			continue
 		}
 
-		waitCh, ok := loader.loadingSegments.Get(segmentID)
+		result, ok := loader.loadingSegments.Get(segmentID)
 		if !ok {
 			log.Warn("segment was removed from the loading map early", zap.Int64("segmentID", segmentID))
 			return errors.New("segment was removed from the loading map early")
 		}
 
 		log.Info("wait segment loaded...", zap.Int64("segmentID", segmentID))
-		select {
-		case <-ctx.Done():
-			return ctx.Err()
-		case <-waitCh:
+
+		signal := make(chan struct{})
+		go func() {
+			select {
+			case <-signal:
+			case <-ctx.Done():
+				result.cond.Broadcast()
+			}
+		}()
+		result.cond.L.Lock()
+		for result.status.Load() == loading && ctx.Err() == nil {
+			result.cond.Wait()
 		}
+		result.cond.L.Unlock()
+		close(signal)
+
+		if result.status.Load() == failure {
+			log.Warn("failed to wait segment loaded", zap.Int64("segmentID", segmentID))
+			return merr.WrapErrSegmentLack(segmentID, "failed to wait segment loaded")
+		}
+		log.Info("segment loaded...", zap.Int64("segmentID", segmentID))
 	}
 
 	return nil
diff --git a/internal/querynodev2/segments/segment_loader_test.go b/internal/querynodev2/segments/segment_loader_test.go
index 19230c2076..4d38719ac2 100644
--- a/internal/querynodev2/segments/segment_loader_test.go
+++ b/internal/querynodev2/segments/segment_loader_test.go
@@ -20,9 +20,12 @@ import (
 	"context"
 	"math/rand"
 	"testing"
+	"time"
 
+	"github.com/stretchr/testify/mock"
 	"github.com/stretchr/testify/suite"
 
+	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
 	"github.com/milvus-io/milvus/internal/storage"
@@ -527,6 +530,122 @@ func (suite *SegmentLoaderSuite) TestRunOutMemory() {
 	suite.Error(err)
 }
 
+type SegmentLoaderDetailSuite struct {
+	suite.Suite
+
+	loader            *segmentLoader
+	manager           *Manager
+	segmentManager    *MockSegmentManager
+	collectionManager *MockCollectionManager
+
+	rootPath     string
+	chunkManager storage.ChunkManager
+
+	// Data
+	collectionID int64
+	partitionID  int64
+	segmentID    int64
+	schema       *schemapb.CollectionSchema
+	segmentNum   int
+}
+
+func (suite *SegmentLoaderDetailSuite) SetupSuite() {
+	paramtable.Init()
+	suite.rootPath = suite.T().Name()
+	suite.collectionID = rand.Int63()
+	suite.partitionID = rand.Int63()
+	suite.segmentID = rand.Int63()
+	suite.segmentNum = 5
+	suite.schema = GenTestCollectionSchema("test", schemapb.DataType_Int64)
+}
+
+func (suite *SegmentLoaderDetailSuite) SetupTest() {
+	// Dependencies
+	suite.collectionManager = NewMockCollectionManager(suite.T())
+	suite.segmentManager = NewMockSegmentManager(suite.T())
+	suite.manager = &Manager{
+		Segment:    suite.segmentManager,
+		Collection: suite.collectionManager,
+	}
+
+	ctx := context.Background()
+	chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath)
+	suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx)
+	suite.loader = NewLoader(suite.manager, suite.chunkManager)
+	initcore.InitRemoteChunkManager(paramtable.Get())
+
+	// Data
+	schema := GenTestCollectionSchema("test", schemapb.DataType_Int64)
+
+	indexMeta := GenTestIndexMeta(suite.collectionID, schema)
+	loadMeta := &querypb.LoadMetaInfo{
+		LoadType:     querypb.LoadType_LoadCollection,
+		CollectionID: suite.collectionID,
+		PartitionIDs: []int64{suite.partitionID},
+	}
+
+	collection := NewCollection(suite.collectionID, schema, indexMeta, loadMeta.GetLoadType())
+	suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection).Maybe()
+}
+
+func (suite *SegmentLoaderDetailSuite) TestWaitSegmentLoadDone() {
+	suite.Run("wait_success", func() {
+		idx := 0
+
+		var infos []*querypb.SegmentLoadInfo
+		suite.segmentManager.EXPECT().GetBy(mock.Anything, mock.Anything).Return(nil)
+		suite.segmentManager.EXPECT().GetWithType(suite.segmentID, SegmentTypeSealed).RunAndReturn(func(segmentID int64, segmentType commonpb.SegmentState) Segment {
+			defer func() { idx++ }()
+			if idx == 0 {
+				go func() {
+					<-time.After(time.Second)
+					suite.loader.notifyLoadFinish(infos...)
+				}()
+			}
+			return nil
+		})
+		infos = suite.loader.prepare(SegmentTypeSealed, 0, &querypb.SegmentLoadInfo{
+			SegmentID:    suite.segmentID,
+			PartitionID:  suite.partitionID,
+			CollectionID: suite.collectionID,
+			NumOfRows:    100,
+		})
+
+		err := suite.loader.waitSegmentLoadDone(context.Background(), SegmentTypeSealed, suite.segmentID)
+		suite.NoError(err)
+	})
+
+	suite.Run("wait_failure", func() {
+		suite.SetupTest()
+
+		var idx int
+		var infos []*querypb.SegmentLoadInfo
+		suite.segmentManager.EXPECT().GetBy(mock.Anything, mock.Anything).Return(nil)
+		suite.segmentManager.EXPECT().GetWithType(suite.segmentID, SegmentTypeSealed).RunAndReturn(func(segmentID int64, segmentType commonpb.SegmentState) Segment {
+			defer func() { idx++ }()
+			if idx == 0 {
+				go func() {
+					<-time.After(time.Second)
+					suite.loader.unregister(infos...)
+				}()
+			}
+
+			return nil
+		})
+		infos = suite.loader.prepare(SegmentTypeSealed, 0, &querypb.SegmentLoadInfo{
+			SegmentID:    suite.segmentID,
+			PartitionID:  suite.partitionID,
+			CollectionID: suite.collectionID,
+			NumOfRows:    100,
+		})
+
+		err := suite.loader.waitSegmentLoadDone(context.Background(), SegmentTypeSealed, suite.segmentID)
+		suite.Error(err)
+	})
+}
+
 func TestSegmentLoader(t *testing.T) {
 	suite.Run(t, &SegmentLoaderSuite{})
+	suite.Run(t, &SegmentLoaderDetailSuite{})
 }
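
Below is a minimal, self-contained sketch of the wait pattern this patch introduces,
runnable outside Milvus with `go run`. All names are illustrative rather than taken
from the Milvus codebase. The sketch uses the standard library's sync/atomic (Go 1.19+)
in place of go.uber.org/atomic, and, unlike the patch, both broadcast sites take the
cond lock first, which is standard condition-variable discipline that rules out a lost
wakeup between the status check and cond.Wait().

package main

import (
	"context"
	"errors"
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// Load states; `loading` is the initial state.
const (
	loading int32 = iota + 1
	success
	failure
)

// loadResult pairs an atomic status with a sync.Cond that waiters
// block on until the status leaves `loading`.
type loadResult struct {
	status atomic.Int32
	cond   *sync.Cond
}

func newLoadResult() *loadResult {
	r := &loadResult{cond: sync.NewCond(&sync.Mutex{})}
	r.status.Store(loading)
	return r
}

// SetResult moves loading -> status at most once (CompareAndSwap keeps
// the first outcome) and wakes every waiter. Broadcasting under the
// lock guarantees no waiter sits between its status check and Wait.
func (r *loadResult) SetResult(status int32) {
	r.status.CompareAndSwap(loading, status)
	r.cond.L.Lock()
	r.cond.Broadcast()
	r.cond.L.Unlock()
}

// wait blocks until the result is set or ctx is cancelled. A sync.Cond
// cannot select on a channel, so a helper goroutine broadcasts once if
// ctx is cancelled first; `done` stops the helper on the normal path.
func wait(ctx context.Context, r *loadResult) error {
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-done:
		case <-ctx.Done():
			r.cond.L.Lock()
			r.cond.Broadcast()
			r.cond.L.Unlock()
		}
	}()

	r.cond.L.Lock()
	for r.status.Load() == loading && ctx.Err() == nil {
		r.cond.Wait()
	}
	r.cond.L.Unlock()

	if err := ctx.Err(); err != nil {
		return err
	}
	if r.status.Load() == failure {
		return errors.New("segment load failed")
	}
	return nil
}

func main() {
	r := newLoadResult()
	go func() {
		time.Sleep(50 * time.Millisecond)
		r.SetResult(success) // an unregister path would record failure instead
	}()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	fmt.Println("wait:", wait(ctx, r)) // wait: <nil>
}

The CompareAndSwap in SetResult makes the first recorded outcome sticky: once
unregister has marked a segment as failure, a late notifyLoadFinish cannot flip it
back to success. That is exactly the false-success hole the closed-channel version
left open, since a closed channel looks the same to a waiter whether the load
succeeded or was aborted.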