From d79aa58b379f68b7703056bdff39cacfe3f6c5cb Mon Sep 17 00:00:00 2001
From: wei liu
Date: Fri, 15 Mar 2024 14:19:03 +0800
Subject: [PATCH] enhance: Speed up target recovery after query coord restart
 (#31240)

issue: #28491

After querycoord restarts, it pulls a new target, which includes the
channel and segment lists. Once the segments loaded on the querynodes
reach that target, the collection can serve search/query requests. But
if the segment list has changed over time, it takes a few minutes after
querycoord pulls the new target to catch up with the target's segment
distribution, and until then query/search requests fail due to missing
segments.

This PR saves the currently loaded target to the meta store during
querycoord's stop procedure and recovers it when querycoord starts, to
speed up target recovery.

---------

Signed-off-by: Wei Liu
---
 internal/distributed/querycoord/service.go    |   4 +
 internal/metastore/catalog.go                 |   4 +
 .../metastore/kv/querycoord/kv_catalog.go     |  55 ++++-
 .../kv/querycoord/kv_catalog_test.go          |  46 ++++
 .../mocks/mock_querycoord_catalog.go          | 137 +++++++++++
 internal/proto/query_coord.proto              |  24 ++
 internal/querycoordv2/meta/target.go          |  84 +++++++
 internal/querycoordv2/meta/target_manager.go  |  44 ++++
 .../querycoordv2/meta/target_manager_test.go  |  81 ++++++-
 internal/querycoordv2/server.go               |  10 +
 tests/integration/minicluster_v2.go           |   8 +-
 tests/integration/target/target_test.go       | 221 ++++++++++++++++++
 12 files changed, 708 insertions(+), 10 deletions(-)
 create mode 100644 tests/integration/target/target_test.go

diff --git a/internal/distributed/querycoord/service.go b/internal/distributed/querycoord/service.go
index 573438c55c..f81546fa45 100644
--- a/internal/distributed/querycoord/service.go
+++ b/internal/distributed/querycoord/service.go
@@ -271,6 +271,10 @@ func (s *Server) start() error {
 	return s.queryCoord.Start()
 }
 
+func (s *Server) GetQueryCoord() types.QueryCoordComponent {
+	return s.queryCoord
+}
+
 // Stop stops QueryCoord's grpc service.
 func (s *Server) Stop() (err error) {
 	Params := &paramtable.Get().QueryCoordGrpcServerCfg
diff --git a/internal/metastore/catalog.go b/internal/metastore/catalog.go
index 84892de9f1..f6c23dc684 100644
--- a/internal/metastore/catalog.go
+++ b/internal/metastore/catalog.go
@@ -163,4 +163,8 @@ type QueryCoordCatalog interface {
 	SaveResourceGroup(rgs ...*querypb.ResourceGroup) error
 	RemoveResourceGroup(rgName string) error
 	GetResourceGroups() ([]*querypb.ResourceGroup, error)
+
+	SaveCollectionTarget(target *querypb.CollectionTarget) error
+	RemoveCollectionTarget(collectionID int64) error
+	GetCollectionTargets() (map[int64]*querypb.CollectionTarget, error)
 }
diff --git a/internal/metastore/kv/querycoord/kv_catalog.go b/internal/metastore/kv/querycoord/kv_catalog.go
index dde15d6001..2ac1fa6039 100644
--- a/internal/metastore/kv/querycoord/kv_catalog.go
+++ b/internal/metastore/kv/querycoord/kv_catalog.go
@@ -1,15 +1,21 @@
 package querycoord
 
 import (
+	"bytes"
 	"fmt"
+	"io"
 
 	"github.com/cockroachdb/errors"
 	"github.com/golang/protobuf/proto"
+	"github.com/klauspost/compress/zstd"
+	"github.com/pingcap/log"
 	"github.com/samber/lo"
+	"go.uber.org/zap"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
 	"github.com/milvus-io/milvus/internal/kv"
 	"github.com/milvus-io/milvus/internal/proto/querypb"
+	"github.com/milvus-io/milvus/pkg/util/compressor"
 )
 
 var ErrInvalidKey = errors.New("invalid load info key")
@@ -22,7 +28,8 @@ const (
 	ReplicaMetaPrefixV1 = "queryCoord-ReplicaMeta"
 	ResourceGroupPrefix = "queryCoord-ResourceGroup"
 
-	MetaOpsBatchSize = 128
+	MetaOpsBatchSize       = 128
+	CollectionTargetPrefix = "queryCoord-Collection-Target"
 )
 
 type Catalog struct {
@@ -234,6 +241,48 @@ func (s Catalog) ReleaseReplica(collection, replica int64) error {
 	return s.cli.Remove(key)
 }
 
+func (s Catalog) SaveCollectionTarget(target *querypb.CollectionTarget) error {
+	k := encodeCollectionTargetKey(target.GetCollectionID())
+	v, err := proto.Marshal(target)
+	if err != nil {
+		return err
+	}
+	// to reduce the target size, compress it before writing to etcd
+	var compressed bytes.Buffer
+	if err := compressor.ZstdCompress(bytes.NewReader(v), io.Writer(&compressed), zstd.WithEncoderLevel(zstd.SpeedBetterCompression)); err != nil {
+		return err
+	}
+	return s.cli.Save(k, compressed.String())
+}
+
+func (s Catalog) RemoveCollectionTarget(collectionID int64) error {
+	k := encodeCollectionTargetKey(collectionID)
+	return s.cli.Remove(k)
+}
+
+func (s Catalog) GetCollectionTargets() (map[int64]*querypb.CollectionTarget, error) {
+	keys, values, err := s.cli.LoadWithPrefix(CollectionTargetPrefix)
+	if err != nil {
+		return nil, err
+	}
+	ret := make(map[int64]*querypb.CollectionTarget)
+	for i, v := range values {
+		var decompressed bytes.Buffer
+		if err := compressor.ZstdDecompress(bytes.NewReader([]byte(v)), io.Writer(&decompressed)); err != nil {
+			// recovering the target from meta is only an optimization, so skip entries that fail
+			log.Warn("failed to decompress collection target", zap.String("key", keys[i]), zap.Error(err))
+			continue
+		}
+		target := &querypb.CollectionTarget{}
+		if err := proto.Unmarshal(decompressed.Bytes(), target); err != nil {
+			// recovering the target from meta is only an optimization, so skip entries that fail
+			log.Warn("failed to unmarshal collection target", zap.String("key", keys[i]), zap.Error(err))
+			continue
+		}
+		ret[target.GetCollectionID()] = target
+	}
+
+	return ret, nil
+}
+
 func EncodeCollectionLoadInfoKey(collection int64) string {
 	return fmt.Sprintf("%s/%d", CollectionLoadInfoPrefix, collection)
 }
@@ -253,3 +302,7 @@ func encodeCollectionReplicaKey(collection int64) string {
 func encodeResourceGroupKey(rgName string) string {
 	return fmt.Sprintf("%s/%s", ResourceGroupPrefix, rgName)
 }
+
+func
encodeCollectionTargetKey(collection int64) string { + return fmt.Sprintf("%s/%d", CollectionTargetPrefix, collection) +} diff --git a/internal/metastore/kv/querycoord/kv_catalog_test.go b/internal/metastore/kv/querycoord/kv_catalog_test.go index 4a37760321..349e94539f 100644 --- a/internal/metastore/kv/querycoord/kv_catalog_test.go +++ b/internal/metastore/kv/querycoord/kv_catalog_test.go @@ -4,10 +4,13 @@ import ( "sort" "testing" + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" + "github.com/milvus-io/milvus/internal/kv/mocks" "github.com/milvus-io/milvus/internal/proto/querypb" . "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/pkg/util/etcd" @@ -199,6 +202,49 @@ func (suite *CatalogTestSuite) TestResourceGroup() { suite.Equal([]int64{4, 5}, groups[1].GetNodes()) } +func (suite *CatalogTestSuite) TestCollectionTarget() { + suite.catalog.SaveCollectionTarget(&querypb.CollectionTarget{ + CollectionID: 1, + Version: 1, + }) + suite.catalog.SaveCollectionTarget(&querypb.CollectionTarget{ + CollectionID: 2, + Version: 2, + }) + suite.catalog.SaveCollectionTarget(&querypb.CollectionTarget{ + CollectionID: 3, + Version: 3, + }) + suite.catalog.SaveCollectionTarget(&querypb.CollectionTarget{ + CollectionID: 1, + Version: 4, + }) + suite.catalog.RemoveCollectionTarget(2) + + targets, err := suite.catalog.GetCollectionTargets() + suite.NoError(err) + suite.Len(targets, 2) + suite.Equal(int64(4), targets[1].Version) + suite.Equal(int64(3), targets[3].Version) + + // test access meta store failed + mockStore := mocks.NewMetaKv(suite.T()) + mockErr := errors.New("failed to access etcd") + mockStore.EXPECT().Save(mock.Anything, mock.Anything).Return(mockErr) + mockStore.EXPECT().LoadWithPrefix(mock.Anything).Return(nil, nil, mockErr) + + suite.catalog.cli = mockStore + err = suite.catalog.SaveCollectionTarget(&querypb.CollectionTarget{}) + suite.ErrorIs(err, mockErr) + + _, err = suite.catalog.GetCollectionTargets() + suite.ErrorIs(err, mockErr) + + // test invalid message + err = suite.catalog.SaveCollectionTarget(nil) + suite.Error(err) +} + func (suite *CatalogTestSuite) TestLoadRelease() { // TODO(sunby): add ut } diff --git a/internal/metastore/mocks/mock_querycoord_catalog.go b/internal/metastore/mocks/mock_querycoord_catalog.go index 00a4043432..98dfaaf9d7 100644 --- a/internal/metastore/mocks/mock_querycoord_catalog.go +++ b/internal/metastore/mocks/mock_querycoord_catalog.go @@ -20,6 +20,59 @@ func (_m *QueryCoordCatalog) EXPECT() *QueryCoordCatalog_Expecter { return &QueryCoordCatalog_Expecter{mock: &_m.Mock} } +// GetCollectionTargets provides a mock function with given fields: +func (_m *QueryCoordCatalog) GetCollectionTargets() (map[int64]*querypb.CollectionTarget, error) { + ret := _m.Called() + + var r0 map[int64]*querypb.CollectionTarget + var r1 error + if rf, ok := ret.Get(0).(func() (map[int64]*querypb.CollectionTarget, error)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() map[int64]*querypb.CollectionTarget); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[int64]*querypb.CollectionTarget) + } + } + + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// QueryCoordCatalog_GetCollectionTargets_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 
'GetCollectionTargets' +type QueryCoordCatalog_GetCollectionTargets_Call struct { + *mock.Call +} + +// GetCollectionTargets is a helper method to define mock.On call +func (_e *QueryCoordCatalog_Expecter) GetCollectionTargets() *QueryCoordCatalog_GetCollectionTargets_Call { + return &QueryCoordCatalog_GetCollectionTargets_Call{Call: _e.mock.On("GetCollectionTargets")} +} + +func (_c *QueryCoordCatalog_GetCollectionTargets_Call) Run(run func()) *QueryCoordCatalog_GetCollectionTargets_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *QueryCoordCatalog_GetCollectionTargets_Call) Return(_a0 map[int64]*querypb.CollectionTarget, _a1 error) *QueryCoordCatalog_GetCollectionTargets_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *QueryCoordCatalog_GetCollectionTargets_Call) RunAndReturn(run func() (map[int64]*querypb.CollectionTarget, error)) *QueryCoordCatalog_GetCollectionTargets_Call { + _c.Call.Return(run) + return _c +} + // GetCollections provides a mock function with given fields: func (_m *QueryCoordCatalog) GetCollections() ([]*querypb.CollectionLoadInfo, error) { ret := _m.Called() @@ -416,6 +469,48 @@ func (_c *QueryCoordCatalog_ReleaseReplicas_Call) RunAndReturn(run func(int64) e return _c } +// RemoveCollectionTarget provides a mock function with given fields: collectionID +func (_m *QueryCoordCatalog) RemoveCollectionTarget(collectionID int64) error { + ret := _m.Called(collectionID) + + var r0 error + if rf, ok := ret.Get(0).(func(int64) error); ok { + r0 = rf(collectionID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// QueryCoordCatalog_RemoveCollectionTarget_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveCollectionTarget' +type QueryCoordCatalog_RemoveCollectionTarget_Call struct { + *mock.Call +} + +// RemoveCollectionTarget is a helper method to define mock.On call +// - collectionID int64 +func (_e *QueryCoordCatalog_Expecter) RemoveCollectionTarget(collectionID interface{}) *QueryCoordCatalog_RemoveCollectionTarget_Call { + return &QueryCoordCatalog_RemoveCollectionTarget_Call{Call: _e.mock.On("RemoveCollectionTarget", collectionID)} +} + +func (_c *QueryCoordCatalog_RemoveCollectionTarget_Call) Run(run func(collectionID int64)) *QueryCoordCatalog_RemoveCollectionTarget_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(int64)) + }) + return _c +} + +func (_c *QueryCoordCatalog_RemoveCollectionTarget_Call) Return(_a0 error) *QueryCoordCatalog_RemoveCollectionTarget_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *QueryCoordCatalog_RemoveCollectionTarget_Call) RunAndReturn(run func(int64) error) *QueryCoordCatalog_RemoveCollectionTarget_Call { + _c.Call.Return(run) + return _c +} + // RemoveResourceGroup provides a mock function with given fields: rgName func (_m *QueryCoordCatalog) RemoveResourceGroup(rgName string) error { ret := _m.Called(rgName) @@ -515,6 +610,48 @@ func (_c *QueryCoordCatalog_SaveCollection_Call) RunAndReturn(run func(*querypb. 
return _c } +// SaveCollectionTarget provides a mock function with given fields: target +func (_m *QueryCoordCatalog) SaveCollectionTarget(target *querypb.CollectionTarget) error { + ret := _m.Called(target) + + var r0 error + if rf, ok := ret.Get(0).(func(*querypb.CollectionTarget) error); ok { + r0 = rf(target) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// QueryCoordCatalog_SaveCollectionTarget_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveCollectionTarget' +type QueryCoordCatalog_SaveCollectionTarget_Call struct { + *mock.Call +} + +// SaveCollectionTarget is a helper method to define mock.On call +// - target *querypb.CollectionTarget +func (_e *QueryCoordCatalog_Expecter) SaveCollectionTarget(target interface{}) *QueryCoordCatalog_SaveCollectionTarget_Call { + return &QueryCoordCatalog_SaveCollectionTarget_Call{Call: _e.mock.On("SaveCollectionTarget", target)} +} + +func (_c *QueryCoordCatalog_SaveCollectionTarget_Call) Run(run func(target *querypb.CollectionTarget)) *QueryCoordCatalog_SaveCollectionTarget_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*querypb.CollectionTarget)) + }) + return _c +} + +func (_c *QueryCoordCatalog_SaveCollectionTarget_Call) Return(_a0 error) *QueryCoordCatalog_SaveCollectionTarget_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *QueryCoordCatalog_SaveCollectionTarget_Call) RunAndReturn(run func(*querypb.CollectionTarget) error) *QueryCoordCatalog_SaveCollectionTarget_Call { + _c.Call.Return(run) + return _c +} + // SavePartition provides a mock function with given fields: info func (_m *QueryCoordCatalog) SavePartition(info ...*querypb.PartitionLoadInfo) error { _va := make([]interface{}, len(info)) diff --git a/internal/proto/query_coord.proto b/internal/proto/query_coord.proto index f77ce1c604..df10e85f75 100644 --- a/internal/proto/query_coord.proto +++ b/internal/proto/query_coord.proto @@ -767,3 +767,27 @@ message CheckerInfo { bool activated = 3; bool found = 4; } + +message SegmentTarget { + int64 ID = 1; + data.SegmentLevel level = 2; +} + +message PartitionTarget { + int64 partitionID = 1; + repeated SegmentTarget segments = 2; +} + +message ChannelTarget { + string channelName = 1; + repeated int64 dropped_segmentIDs = 2; + repeated int64 growing_segmentIDs = 3; + repeated PartitionTarget partition_targets = 4; + msg.MsgPosition seek_position = 5; +} + +message CollectionTarget { + int64 collectionID = 1; + repeated ChannelTarget Channel_targets = 2; + int64 version = 3; +} diff --git a/internal/querycoordv2/meta/target.go b/internal/querycoordv2/meta/target.go index 2893f1636a..b7fc06930a 100644 --- a/internal/querycoordv2/meta/target.go +++ b/internal/querycoordv2/meta/target.go @@ -22,6 +22,7 @@ import ( "github.com/samber/lo" "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/querypb" ) // CollectionTarget collection target is immutable, @@ -39,6 +40,89 @@ func NewCollectionTarget(segments map[int64]*datapb.SegmentInfo, dmChannels map[ } } +func FromPbCollectionTarget(target *querypb.CollectionTarget) *CollectionTarget { + segments := make(map[int64]*datapb.SegmentInfo) + dmChannels := make(map[string]*DmChannel) + + for _, t := range target.GetChannelTargets() { + for _, partition := range t.GetPartitionTargets() { + for _, segment := range partition.GetSegments() { + segments[segment.GetID()] = &datapb.SegmentInfo{ + ID: segment.GetID(), + Level: segment.GetLevel(), + CollectionID: 
target.GetCollectionID(),
+					PartitionID:   partition.GetPartitionID(),
+					InsertChannel: t.GetChannelName(),
+				}
+			}
+		}
+		dmChannels[t.GetChannelName()] = &DmChannel{
+			VchannelInfo: &datapb.VchannelInfo{
+				CollectionID:        target.GetCollectionID(),
+				ChannelName:         t.GetChannelName(),
+				SeekPosition:        t.GetSeekPosition(),
+				UnflushedSegmentIds: t.GetGrowingSegmentIDs(),
+				FlushedSegmentIds:   lo.Keys(segments),
+				DroppedSegmentIds:   t.GetDroppedSegmentIDs(),
+			},
+		}
+	}
+
+	return NewCollectionTarget(segments, dmChannels)
+}
+
+func (p *CollectionTarget) toPbMsg() *querypb.CollectionTarget {
+	if len(p.dmChannels) == 0 {
+		return &querypb.CollectionTarget{}
+	}
+
+	channelSegments := make(map[string][]*datapb.SegmentInfo)
+	for _, s := range p.segments {
+		if _, ok := channelSegments[s.GetInsertChannel()]; !ok {
+			channelSegments[s.GetInsertChannel()] = make([]*datapb.SegmentInfo, 0)
+		}
+		channelSegments[s.GetInsertChannel()] = append(channelSegments[s.GetInsertChannel()], s)
+	}
+
+	collectionID := int64(-1)
+	channelTargets := make(map[string]*querypb.ChannelTarget)
+	for _, channel := range p.dmChannels {
+		collectionID = channel.GetCollectionID()
+		partitionTargets := make(map[int64]*querypb.PartitionTarget)
+		if infos, ok := channelSegments[channel.GetChannelName()]; ok {
+			for _, info := range infos {
+				partitionTarget, ok := partitionTargets[info.GetPartitionID()]
+				if !ok {
+					partitionTarget = &querypb.PartitionTarget{
+						PartitionID: info.PartitionID,
+						Segments:    make([]*querypb.SegmentTarget, 0),
+					}
+					partitionTargets[info.GetPartitionID()] = partitionTarget
+				}
+
+				partitionTarget.Segments = append(partitionTarget.Segments, &querypb.SegmentTarget{
+					ID:    info.GetID(),
+					Level: info.GetLevel(),
+				})
+			}
+		}
+
+		channelTargets[channel.GetChannelName()] = &querypb.ChannelTarget{
+			ChannelName:       channel.GetChannelName(),
+			SeekPosition:      channel.GetSeekPosition(),
+			GrowingSegmentIDs: channel.GetUnflushedSegmentIds(),
+			DroppedSegmentIDs: channel.GetDroppedSegmentIds(),
+			PartitionTargets:  lo.Values(partitionTargets),
+		}
+	}
+
+	return &querypb.CollectionTarget{
+		CollectionID:   collectionID,
+		ChannelTargets: lo.Values(channelTargets),
+		Version:        p.version,
+	}
+}
+
 func (p *CollectionTarget) GetAllSegments() map[int64]*datapb.SegmentInfo {
 	return p.segments
 }
diff --git a/internal/querycoordv2/meta/target_manager.go b/internal/querycoordv2/meta/target_manager.go
index 4813b335e2..2553129570 100644
--- a/internal/querycoordv2/meta/target_manager.go
+++ b/internal/querycoordv2/meta/target_manager.go
@@ -25,6 +25,7 @@ import (
 	"go.uber.org/zap"
 
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
+	"github.com/milvus-io/milvus/internal/metastore"
 	"github.com/milvus-io/milvus/internal/proto/datapb"
 	"github.com/milvus-io/milvus/pkg/common"
 	"github.com/milvus-io/milvus/pkg/log"
@@ -528,3 +529,46 @@ func (mgr *TargetManager) IsNextTargetExist(collectionID int64) bool {
 
 	return len(newChannels) > 0
 }
+
+func (mgr *TargetManager) SaveCurrentTarget(catalog metastore.QueryCoordCatalog) {
+	mgr.rwMutex.Lock()
+	defer mgr.rwMutex.Unlock()
+	if mgr.current != nil {
+		for id, target := range mgr.current.collectionTargetMap {
+			if err := catalog.SaveCollectionTarget(target.toPbMsg()); err != nil {
+				log.Warn("failed to save current target for collection", zap.Int64("collectionID", id), zap.Error(err))
+			} else {
+				log.Info("saved current target for collection", zap.Int64("collectionID", id))
+			}
+		}
+	}
+}
+
+func (mgr *TargetManager) Recover(catalog metastore.QueryCoordCatalog) error {
mgr.rwMutex.Lock() + defer mgr.rwMutex.Unlock() + + targets, err := catalog.GetCollectionTargets() + if err != nil { + log.Warn("failed to recover collection target from etcd", zap.Error(err)) + return err + } + + for _, t := range targets { + newTarget := FromPbCollectionTarget(t) + mgr.current.updateCollectionTarget(t.GetCollectionID(), newTarget) + log.Info("recover current target for collection", + zap.Int64("collectionID", t.GetCollectionID()), + zap.Strings("channels", newTarget.GetAllDmChannelNames()), + zap.Int("segmentNum", len(newTarget.GetAllSegmentIDs())), + ) + + // clear target info in meta store + err := catalog.RemoveCollectionTarget(t.GetCollectionID()) + if err != nil { + log.Warn("failed to clear collection target from etcd", zap.Error(err)) + } + } + + return nil +} diff --git a/internal/querycoordv2/meta/target_manager_test.go b/internal/querycoordv2/meta/target_manager_test.go index 67cbcd7f00..5c8a152b33 100644 --- a/internal/querycoordv2/meta/target_manager_test.go +++ b/internal/querycoordv2/meta/target_manager_test.go @@ -29,6 +29,7 @@ import ( "github.com/milvus-io/milvus/internal/kv" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" + "github.com/milvus-io/milvus/internal/metastore" "github.com/milvus-io/milvus/internal/metastore/kv/querycoord" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/querypb" @@ -53,9 +54,10 @@ type TargetManagerSuite struct { allChannels []string allSegments []int64 - kv kv.MetaKv - meta *Meta - broker *MockBroker + kv kv.MetaKv + catalog metastore.QueryCoordCatalog + meta *Meta + broker *MockBroker // Test object mgr *TargetManager } @@ -110,9 +112,9 @@ func (suite *TargetManagerSuite) SetupTest() { suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue()) // meta - store := querycoord.NewCatalog(suite.kv) + suite.catalog = querycoord.NewCatalog(suite.kv) idAllocator := RandomIncrementIDAllocator() - suite.meta = NewMeta(idAllocator, store, session.NewNodeManager()) + suite.meta = NewMeta(idAllocator, suite.catalog, session.NewNodeManager()) suite.broker = NewMockBroker(suite.T()) suite.mgr = NewTargetManager(suite.broker, suite.meta) @@ -547,6 +549,75 @@ func (suite *TargetManagerSuite) TestGetTarget() { } } +func (suite *TargetManagerSuite) TestRecover() { + collectionID := int64(1003) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, NextTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, NextTarget)) + suite.assertSegments([]int64{}, suite.mgr.GetSealedSegmentsByCollection(collectionID, CurrentTarget)) + suite.assertChannels([]string{}, suite.mgr.GetDmChannelsByCollection(collectionID, CurrentTarget)) + + suite.meta.PutCollection(&Collection{ + CollectionLoadInfo: &querypb.CollectionLoadInfo{ + CollectionID: collectionID, + ReplicaNumber: 1, + }, + }) + suite.meta.PutPartition(&Partition{ + PartitionLoadInfo: &querypb.PartitionLoadInfo{ + CollectionID: collectionID, + PartitionID: 1, + }, + }) + + nextTargetChannels := []*datapb.VchannelInfo{ + { + CollectionID: collectionID, + ChannelName: "channel-1", + UnflushedSegmentIds: []int64{1, 2, 3, 4}, + DroppedSegmentIds: []int64{11, 22, 33}, + }, + { + CollectionID: collectionID, + ChannelName: "channel-2", + UnflushedSegmentIds: []int64{5}, + }, + } + + nextTargetSegments := []*datapb.SegmentInfo{ + { + ID: 11, + PartitionID: 1, + InsertChannel: "channel-1", + }, + { + ID: 12, + PartitionID: 1, + InsertChannel: "channel-2", + }, + } + + 
suite.broker.EXPECT().GetRecoveryInfoV2(mock.Anything, collectionID).Return(nextTargetChannels, nextTargetSegments, nil)
+	suite.mgr.UpdateCollectionNextTarget(collectionID)
+	suite.mgr.UpdateCollectionCurrentTarget(collectionID)
+
+	suite.mgr.SaveCurrentTarget(suite.catalog)
+
+	// clear the in-memory target
+	suite.mgr.current.removeCollectionTarget(collectionID)
+	// try to recover
+	suite.mgr.Recover(suite.catalog)
+
+	target := suite.mgr.current.getCollectionTarget(collectionID)
+	suite.NotNil(target)
+	suite.Len(target.GetAllDmChannelNames(), 2)
+	suite.Len(target.GetAllSegmentIDs(), 2)
+
+	// after recovery, the target info in the meta store should be cleaned up
+	targets, err := suite.catalog.GetCollectionTargets()
+	suite.NoError(err)
+	suite.Len(targets, 0)
+}
+
 func TestTargetManager(t *testing.T) {
 	suite.Run(t, new(TargetManagerSuite))
 }
diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go
index a548c17dca..69ffb9efd3 100644
--- a/internal/querycoordv2/server.go
+++ b/internal/querycoordv2/server.go
@@ -349,6 +349,11 @@ func (s *Server) initMeta() error {
 		LeaderViewManager: meta.NewLeaderViewManager(),
 	}
 	s.targetMgr = meta.NewTargetManager(s.broker, s.meta)
+	err = s.targetMgr.Recover(s.store)
+	if err != nil {
+		log.Warn("failed to recover collection targets", zap.Error(err))
+	}
+
 	log.Info("QueryCoord server initMeta done", zap.Duration("duration", record.ElapseSpan()))
 	return nil
 }
@@ -454,6 +459,11 @@ func (s *Server) startServerLoop() {
 }
 
 func (s *Server) Stop() error {
+	// save the current target to the meta store, so that it can be recovered quickly after querycoord restarts
+	if s.targetMgr != nil {
+		s.targetMgr.SaveCurrentTarget(s.store)
+	}
+
 	// FOLLOW the dependence graph:
 	// job scheduler -> checker controller -> task scheduler -> dist controller -> cluster -> session
 	// observers -> dist controller
diff --git a/tests/integration/minicluster_v2.go b/tests/integration/minicluster_v2.go
index 7b722f0511..71386834fd 100644
--- a/tests/integration/minicluster_v2.go
+++ b/tests/integration/minicluster_v2.go
@@ -164,7 +164,7 @@ func StartMiniClusterV2(ctx context.Context, opts ...OptionV2) (*MiniClusterV2,
 		etcdCli: cluster.EtcdCli,
 	}
 
-	ports, err := GetAvailablePorts(7)
+	ports, err := cluster.GetAvailablePorts(7)
 	if err != nil {
 		return nil, err
 	}
@@ -421,10 +421,10 @@ func (cluster *MiniClusterV2) GetFactory() dependency.Factory {
 	return cluster.factory
 }
 
-func GetAvailablePorts(n int) ([]int, error) {
+func (cluster *MiniClusterV2) GetAvailablePorts(n int) ([]int, error) {
 	ports := make([]int, n)
 	for i := range ports {
-		port, err := GetAvailablePort()
+		port, err := cluster.GetAvailablePort()
 		if err != nil {
 			return nil, err
 		}
@@ -433,7 +433,7 @@ func GetAvailablePorts(n int) ([]int, error) {
 	return ports, nil
 }
 
-func GetAvailablePort() (int, error) {
+func (cluster *MiniClusterV2) GetAvailablePort() (int, error) {
 	address, err := net.ResolveTCPAddr("tcp", fmt.Sprintf("%s:0", "0.0.0.0"))
 	if err != nil {
 		return 0, err
diff --git a/tests/integration/target/target_test.go b/tests/integration/target/target_test.go
new file mode 100644
index 0000000000..e6b739d69c
--- /dev/null
+++ b/tests/integration/target/target_test.go
@@ -0,0 +1,221 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package balance + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/suite" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + grpcquerycoord "github.com/milvus-io/milvus/internal/distributed/querycoord" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/commonpbutil" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metric" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/tests/integration" +) + +const ( + dim = 128 + dbName = "" +) + +type TargetTestSuit struct { + integration.MiniClusterSuite +} + +func (s *TargetTestSuit) SetupSuite() { + paramtable.Init() + paramtable.Get().Save(paramtable.Get().QueryCoordCfg.BalanceCheckInterval.Key, "1000") + paramtable.Get().Save(paramtable.Get().QueryNodeCfg.GracefulStopTimeout.Key, "1") + + s.Require().NoError(s.SetupEmbedEtcd()) +} + +func (s *TargetTestSuit) initCollection(collectionName string, replica int, channelNum int, segmentNum int, segmentRowNum int) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + schema := integration.ConstructSchema(collectionName, dim, true) + marshaledSchema, err := proto.Marshal(schema) + s.NoError(err) + + createCollectionStatus, err := s.Cluster.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: int32(channelNum), + }) + s.NoError(err) + s.True(merr.Ok(createCollectionStatus)) + + log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus)) + showCollectionsResp, err := s.Cluster.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{}) + s.NoError(err) + s.True(merr.Ok(showCollectionsResp.Status)) + log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp)) + + for i := 0; i < segmentNum; i++ { + s.insertToCollection(ctx, dbName, collectionName, segmentRowNum, dim) + } + + // create index + createIndexStatus, err := s.Cluster.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{ + CollectionName: collectionName, + FieldName: integration.FloatVecField, + IndexName: "_default", + ExtraParams: integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2), + }) + s.NoError(err) + s.True(merr.Ok(createIndexStatus)) + s.WaitForIndexBuilt(ctx, collectionName, integration.FloatVecField) + + for i := 1; i < replica; i++ { + s.Cluster.AddQueryNode() + } + + // load + loadStatus, err := s.Cluster.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ + DbName: dbName, + CollectionName: collectionName, + 
ReplicaNumber:  int32(replica),
+	})
+	s.NoError(err)
+	s.Equal(commonpb.ErrorCode_Success, loadStatus.GetErrorCode())
+	s.True(merr.Ok(loadStatus))
+	s.WaitForLoad(ctx, collectionName)
+	log.Info("initCollection Done")
+}
+
+func (s *TargetTestSuit) insertToCollection(ctx context.Context, dbName string, collectionName string, rowCount int, dim int) {
+	fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowCount, dim)
+	hashKeys := integration.GenerateHashKeys(rowCount)
+	insertResult, err := s.Cluster.Proxy.Insert(ctx, &milvuspb.InsertRequest{
+		DbName:         dbName,
+		CollectionName: collectionName,
+		FieldsData:     []*schemapb.FieldData{fVecColumn},
+		HashKeys:       hashKeys,
+		NumRows:        uint32(rowCount),
+	})
+	s.NoError(err)
+	s.True(merr.Ok(insertResult.Status))
+
+	// flush
+	flushResp, err := s.Cluster.Proxy.Flush(ctx, &milvuspb.FlushRequest{
+		DbName:          dbName,
+		CollectionNames: []string{collectionName},
+	})
+	s.NoError(err)
+	segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
+	s.Require().True(has)
+	s.Require().NotEmpty(segmentIDs)
+	ids := segmentIDs.GetData()
+	flushTs, has := flushResp.GetCollFlushTs()[collectionName]
+	s.True(has)
+	s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName)
+}
+
+func (s *TargetTestSuit) TestQueryCoordRestart() {
+	name := "test_balance_" + funcutil.GenRandomStr()
+
+	// generate several small segments here, so the segment list will change over time
+	s.initCollection(name, 1, 2, 2, 2000)
+
+	ctx := context.Background()
+
+	info, err := s.Cluster.Proxy.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{
+		Base:           commonpbutil.NewMsgBase(),
+		CollectionName: name,
+	})
+	s.NoError(err)
+	s.True(merr.Ok(info.GetStatus()))
+	collectionID := info.GetCollectionID()
+
+	// stop the old query coord
+	s.Cluster.QueryCoord.Stop()
+
+	// keep inserting so that the segment list keeps changing
+	closeInsertCh := make(chan struct{})
+	wg := sync.WaitGroup{}
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		for {
+			select {
+			case <-closeInsertCh:
+				log.Info("insert to collection finished")
+				return
+			case <-time.After(time.Second):
+				s.insertToCollection(ctx, dbName, name, 2000, dim)
+				log.Info("insert 2000 rows to collection finished")
+			}
+		}
+	}()
+
+	// sleep 30s to wait for new flushed segments to be generated
+	time.Sleep(30 * time.Second)
+
+	port, err := s.Cluster.GetAvailablePort()
+	s.NoError(err)
+	paramtable.Get().Save(paramtable.Get().QueryCoordGrpcServerCfg.Port.Key, fmt.Sprint(port))
+
+	// start a new QC
+	newQC, err := grpcquerycoord.NewServer(ctx, s.Cluster.GetFactory())
+	s.NoError(err)
+	go func() {
+		err := newQC.Run()
+		s.NoError(err)
+	}()
+	s.Cluster.QueryCoord = newQC
+
+	// after the new QC becomes healthy, the new target should be ready immediately and GetShardLeaders should succeed
+	s.Eventually(func() bool {
+		resp, err := newQC.CheckHealth(ctx, &milvuspb.CheckHealthRequest{})
+		s.NoError(err)
+		if resp.IsHealthy {
+			resp, err := s.Cluster.QueryCoord.GetShardLeaders(ctx, &querypb.GetShardLeadersRequest{
+				Base:         commonpbutil.NewMsgBase(),
+				CollectionID: collectionID,
+			})
+			log.Info("resp", zap.Any("status", resp.GetStatus()), zap.Any("shards", resp.Shards))
+			s.NoError(err)
+			s.True(merr.Ok(resp.GetStatus()))
+
+			return len(resp.Shards) == 2
+		}
+		return false
+	}, 60*time.Second, 1*time.Second)
+
+	close(closeInsertCh)
+	wg.Wait()
+}
+
+func TestTarget(t *testing.T) {
+	suite.Run(t, new(TargetTestSuit))
+}
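
For context beyond the patch itself: the new catalog methods persist each CollectionTarget in etcd as a zstd-compressed protobuf blob. Below is a minimal, self-contained sketch of that compress/decompress round trip. It is an illustration only: it uses github.com/klauspost/compress/zstd directly instead of Milvus's internal pkg/util/compressor wrapper, and the payload is a hypothetical stand-in for a marshaled CollectionTarget.

package main

import (
	"bytes"
	"fmt"
	"io"

	"github.com/klauspost/compress/zstd"
)

// compress mirrors the write path of SaveCollectionTarget: the marshaled
// target bytes are zstd-compressed before being saved to etcd.
func compress(plain []byte) ([]byte, error) {
	var buf bytes.Buffer
	enc, err := zstd.NewWriter(&buf, zstd.WithEncoderLevel(zstd.SpeedBetterCompression))
	if err != nil {
		return nil, err
	}
	if _, err := enc.Write(plain); err != nil {
		enc.Close()
		return nil, err
	}
	// Close flushes the final frame; the stream is incomplete without it.
	if err := enc.Close(); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

// decompress mirrors the read path of GetCollectionTargets.
func decompress(compressed []byte) ([]byte, error) {
	dec, err := zstd.NewReader(bytes.NewReader(compressed))
	if err != nil {
		return nil, err
	}
	defer dec.Close()
	return io.ReadAll(dec)
}

func main() {
	// hypothetical stand-in for proto.Marshal of a CollectionTarget
	payload := bytes.Repeat([]byte("channel-1/segment-42;"), 1024)

	c, err := compress(payload)
	if err != nil {
		panic(err)
	}
	p, err := decompress(c)
	if err != nil {
		panic(err)
	}
	fmt.Printf("raw=%dB compressed=%dB roundtrip-ok=%v\n", len(payload), len(c), bytes.Equal(payload, p))
}

Target metadata is highly repetitive (channel names, segment IDs), so it compresses well; that is presumably what makes persisting the full current target on every shutdown cheap enough to be practical.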