enhance: Save collection targets by batches (#31616)

See also #28491 #31240

When the number of collections is large, querycoord saves collection targets
one by one, which is slow and may block querycoord from exiting.

In a local run, a scenario with 500 collections may take about 40 seconds
to save the collection targets.

This PR changes the `SaveCollectionTarget` interface into a batch one
(`SaveCollectionTargets`) and groups the collections into batches of 16 to
accelerate this procedure.

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
congqixia 2024-03-27 00:09:08 +08:00 committed by GitHub
parent 248c923e59
commit 8e5865f630
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 102 additions and 54 deletions

View File

@ -165,7 +165,7 @@ type QueryCoordCatalog interface {
RemoveResourceGroup(rgName string) error RemoveResourceGroup(rgName string) error
GetResourceGroups() ([]*querypb.ResourceGroup, error) GetResourceGroups() ([]*querypb.ResourceGroup, error)
SaveCollectionTarget(target *querypb.CollectionTarget) error SaveCollectionTargets(target ...*querypb.CollectionTarget) error
RemoveCollectionTarget(collectionID int64) error RemoveCollectionTarget(collectionID int64) error
GetCollectionTargets() (map[int64]*querypb.CollectionTarget, error) GetCollectionTargets() (map[int64]*querypb.CollectionTarget, error)
} }

View File

@ -241,16 +241,21 @@ func (s Catalog) ReleaseReplica(collection, replica int64) error {
return s.cli.Remove(key) return s.cli.Remove(key)
} }
func (s Catalog) SaveCollectionTarget(target *querypb.CollectionTarget) error { func (s Catalog) SaveCollectionTargets(targets ...*querypb.CollectionTarget) error {
k := encodeCollectionTargetKey(target.GetCollectionID()) kvs := make(map[string]string)
v, err := proto.Marshal(target) for _, target := range targets {
if err != nil { k := encodeCollectionTargetKey(target.GetCollectionID())
return err v, err := proto.Marshal(target)
if err != nil {
return err
}
var compressed bytes.Buffer
compressor.ZstdCompress(bytes.NewReader(v), io.Writer(&compressed), zstd.WithEncoderLevel(zstd.SpeedBetterCompression))
kvs[k] = compressed.String()
} }
// to reduce the target size, we do compress before write to etcd // to reduce the target size, we do compress before write to etcd
var compressed bytes.Buffer err := s.cli.MultiSave(kvs)
compressor.ZstdCompress(bytes.NewReader(v), io.Writer(&compressed), zstd.WithEncoderLevel(zstd.SpeedBetterCompression))
err = s.cli.Save(k, compressed.String())
if err != nil { if err != nil {
return err return err
} }

View File

@ -203,22 +203,22 @@ func (suite *CatalogTestSuite) TestResourceGroup() {
} }
func (suite *CatalogTestSuite) TestCollectionTarget() { func (suite *CatalogTestSuite) TestCollectionTarget() {
suite.catalog.SaveCollectionTarget(&querypb.CollectionTarget{ suite.catalog.SaveCollectionTargets(&querypb.CollectionTarget{
CollectionID: 1, CollectionID: 1,
Version: 1, Version: 1,
}) },
suite.catalog.SaveCollectionTarget(&querypb.CollectionTarget{ &querypb.CollectionTarget{
CollectionID: 2, CollectionID: 2,
Version: 2, Version: 2,
}) },
suite.catalog.SaveCollectionTarget(&querypb.CollectionTarget{ &querypb.CollectionTarget{
CollectionID: 3, CollectionID: 3,
Version: 3, Version: 3,
}) },
suite.catalog.SaveCollectionTarget(&querypb.CollectionTarget{ &querypb.CollectionTarget{
CollectionID: 1, CollectionID: 1,
Version: 4, Version: 4,
}) })
suite.catalog.RemoveCollectionTarget(2) suite.catalog.RemoveCollectionTarget(2)
targets, err := suite.catalog.GetCollectionTargets() targets, err := suite.catalog.GetCollectionTargets()
@ -230,18 +230,18 @@ func (suite *CatalogTestSuite) TestCollectionTarget() {
// test access meta store failed // test access meta store failed
mockStore := mocks.NewMetaKv(suite.T()) mockStore := mocks.NewMetaKv(suite.T())
mockErr := errors.New("failed to access etcd") mockErr := errors.New("failed to access etcd")
mockStore.EXPECT().Save(mock.Anything, mock.Anything).Return(mockErr) mockStore.EXPECT().MultiSave(mock.Anything).Return(mockErr)
mockStore.EXPECT().LoadWithPrefix(mock.Anything).Return(nil, nil, mockErr) mockStore.EXPECT().LoadWithPrefix(mock.Anything).Return(nil, nil, mockErr)
suite.catalog.cli = mockStore suite.catalog.cli = mockStore
err = suite.catalog.SaveCollectionTarget(&querypb.CollectionTarget{}) err = suite.catalog.SaveCollectionTargets(&querypb.CollectionTarget{})
suite.ErrorIs(err, mockErr) suite.ErrorIs(err, mockErr)
_, err = suite.catalog.GetCollectionTargets() _, err = suite.catalog.GetCollectionTargets()
suite.ErrorIs(err, mockErr) suite.ErrorIs(err, mockErr)
// test invalid message // test invalid message
err = suite.catalog.SaveCollectionTarget(nil) err = suite.catalog.SaveCollectionTargets(nil)
suite.Error(err) suite.Error(err)
} }

View File

@ -610,13 +610,19 @@ func (_c *QueryCoordCatalog_SaveCollection_Call) RunAndReturn(run func(*querypb.
return _c return _c
} }
// SaveCollectionTarget provides a mock function with given fields: target // SaveCollectionTargets provides a mock function with given fields: target
func (_m *QueryCoordCatalog) SaveCollectionTarget(target *querypb.CollectionTarget) error { func (_m *QueryCoordCatalog) SaveCollectionTargets(target ...*querypb.CollectionTarget) error {
ret := _m.Called(target) _va := make([]interface{}, len(target))
for _i := range target {
_va[_i] = target[_i]
}
var _ca []interface{}
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(*querypb.CollectionTarget) error); ok { if rf, ok := ret.Get(0).(func(...*querypb.CollectionTarget) error); ok {
r0 = rf(target) r0 = rf(target...)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)
} }
@ -624,30 +630,37 @@ func (_m *QueryCoordCatalog) SaveCollectionTarget(target *querypb.CollectionTarg
return r0 return r0
} }
// QueryCoordCatalog_SaveCollectionTarget_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveCollectionTarget' // QueryCoordCatalog_SaveCollectionTargets_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SaveCollectionTargets'
type QueryCoordCatalog_SaveCollectionTarget_Call struct { type QueryCoordCatalog_SaveCollectionTargets_Call struct {
*mock.Call *mock.Call
} }
// SaveCollectionTarget is a helper method to define mock.On call // SaveCollectionTargets is a helper method to define mock.On call
// - target *querypb.CollectionTarget // - target ...*querypb.CollectionTarget
func (_e *QueryCoordCatalog_Expecter) SaveCollectionTarget(target interface{}) *QueryCoordCatalog_SaveCollectionTarget_Call { func (_e *QueryCoordCatalog_Expecter) SaveCollectionTargets(target ...interface{}) *QueryCoordCatalog_SaveCollectionTargets_Call {
return &QueryCoordCatalog_SaveCollectionTarget_Call{Call: _e.mock.On("SaveCollectionTarget", target)} return &QueryCoordCatalog_SaveCollectionTargets_Call{Call: _e.mock.On("SaveCollectionTargets",
append([]interface{}{}, target...)...)}
} }
func (_c *QueryCoordCatalog_SaveCollectionTarget_Call) Run(run func(target *querypb.CollectionTarget)) *QueryCoordCatalog_SaveCollectionTarget_Call { func (_c *QueryCoordCatalog_SaveCollectionTargets_Call) Run(run func(target ...*querypb.CollectionTarget)) *QueryCoordCatalog_SaveCollectionTargets_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
run(args[0].(*querypb.CollectionTarget)) variadicArgs := make([]*querypb.CollectionTarget, len(args)-0)
for i, a := range args[0:] {
if a != nil {
variadicArgs[i] = a.(*querypb.CollectionTarget)
}
}
run(variadicArgs...)
}) })
return _c return _c
} }
func (_c *QueryCoordCatalog_SaveCollectionTarget_Call) Return(_a0 error) *QueryCoordCatalog_SaveCollectionTarget_Call { func (_c *QueryCoordCatalog_SaveCollectionTargets_Call) Return(_a0 error) *QueryCoordCatalog_SaveCollectionTargets_Call {
_c.Call.Return(_a0) _c.Call.Return(_a0)
return _c return _c
} }
func (_c *QueryCoordCatalog_SaveCollectionTarget_Call) RunAndReturn(run func(*querypb.CollectionTarget) error) *QueryCoordCatalog_SaveCollectionTarget_Call { func (_c *QueryCoordCatalog_SaveCollectionTargets_Call) RunAndReturn(run func(...*querypb.CollectionTarget) error) *QueryCoordCatalog_SaveCollectionTargets_Call {
_c.Call.Return(run) _c.Call.Return(run)
return _c return _c
} }

View File

@ -19,6 +19,7 @@ package meta
import ( import (
"context" "context"
"fmt" "fmt"
"runtime"
"sync" "sync"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
@ -28,9 +29,11 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/metastore" "github.com/milvus-io/milvus/internal/metastore"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/conc"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/retry" "github.com/milvus-io/milvus/pkg/util/retry"
@ -594,13 +597,38 @@ func (mgr *TargetManager) SaveCurrentTarget(catalog metastore.QueryCoordCatalog)
mgr.rwMutex.Lock() mgr.rwMutex.Lock()
defer mgr.rwMutex.Unlock() defer mgr.rwMutex.Unlock()
if mgr.current != nil { if mgr.current != nil {
// use pool here to control maximal writer used by save target
pool := conc.NewPool[any](runtime.GOMAXPROCS(0) * 2)
// use batch write in case of the number of collections is large
batchSize := 16
var wg sync.WaitGroup
submit := func(tasks []typeutil.Pair[int64, *querypb.CollectionTarget]) {
wg.Add(1)
pool.Submit(func() (any, error) {
defer wg.Done()
ids := lo.Map(tasks, func(p typeutil.Pair[int64, *querypb.CollectionTarget], _ int) int64 { return p.A })
if err := catalog.SaveCollectionTargets(lo.Map(tasks, func(p typeutil.Pair[int64, *querypb.CollectionTarget], _ int) *querypb.CollectionTarget {
return p.B
})...); err != nil {
log.Warn("failed to save current target for collection", zap.Int64s("collectionIDs", ids), zap.Error(err))
} else {
log.Info("succeed to save current target for collection", zap.Int64s("collectionIDs", ids))
}
return nil, nil
})
}
tasks := make([]typeutil.Pair[int64, *querypb.CollectionTarget], 0, batchSize)
for id, target := range mgr.current.collectionTargetMap { for id, target := range mgr.current.collectionTargetMap {
if err := catalog.SaveCollectionTarget(target.toPbMsg()); err != nil { tasks = append(tasks, typeutil.NewPair(id, target.toPbMsg()))
log.Warn("failed to save current target for collection", zap.Int64("collectionID", id), zap.Error(err)) if len(tasks) >= batchSize {
} else { submit(tasks)
log.Warn("succeed to save current target for collection", zap.Int64("collectionID", id)) tasks = make([]typeutil.Pair[int64, *querypb.CollectionTarget], 0, batchSize)
} }
} }
if len(tasks) > 0 {
submit(tasks)
}
wg.Wait()
} }
} }

View File

@ -26,6 +26,7 @@ import (
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
@ -146,8 +147,8 @@ func (s *CoordSwitchSuite) checkCollections() bool {
TimeStamp: 0, // means now TimeStamp: 0, // means now
} }
resp, err := s.Cluster.Proxy.ShowCollections(context.TODO(), req) resp, err := s.Cluster.Proxy.ShowCollections(context.TODO(), req)
s.NoError(err) s.Require().NoError(merr.CheckRPCCall(resp, err))
s.Equal(len(resp.CollectionIds), numCollections) s.Require().Equal(len(resp.CollectionIds), numCollections)
notLoaded := 0 notLoaded := 0
loaded := 0 loaded := 0
for _, name := range resp.CollectionNames { for _, name := range resp.CollectionNames {
@ -181,7 +182,7 @@ func (s *CoordSwitchSuite) search(collectionName string, dim int) {
GuaranteeTimestamp: 0, GuaranteeTimestamp: 0,
} }
queryResult, err := c.Proxy.Query(context.TODO(), queryReq) queryResult, err := c.Proxy.Query(context.TODO(), queryReq)
s.NoError(err) s.Require().NoError(merr.CheckRPCCall(queryResult, err))
s.Equal(len(queryResult.FieldsData), 1) s.Equal(len(queryResult.FieldsData), 1)
numEntities := queryResult.FieldsData[0].GetScalars().GetLongData().Data[0] numEntities := queryResult.FieldsData[0].GetScalars().GetLongData().Data[0]
s.Equal(numEntities, int64(rowsPerCollection)) s.Equal(numEntities, int64(rowsPerCollection))
@ -198,10 +199,9 @@ func (s *CoordSwitchSuite) search(collectionName string, dim int) {
searchReq := integration.ConstructSearchRequest("", collectionName, expr, searchReq := integration.ConstructSearchRequest("", collectionName, expr,
integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal) integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.IP, params, nq, dim, topk, roundDecimal)
searchResult, _ := c.Proxy.Search(context.TODO(), searchReq) searchResult, err := c.Proxy.Search(context.TODO(), searchReq)
err = merr.Error(searchResult.GetStatus()) s.NoError(merr.CheckRPCCall(searchResult, err))
s.NoError(err)
} }
func (s *CoordSwitchSuite) insertBatchCollections(prefix string, collectionBatchSize, idxStart, dim int, wg *sync.WaitGroup) { func (s *CoordSwitchSuite) insertBatchCollections(prefix string, collectionBatchSize, idxStart, dim int, wg *sync.WaitGroup) {
@ -229,7 +229,7 @@ func (s *CoordSwitchSuite) setupData() {
} }
wg.Wait() wg.Wait()
log.Info("=========================Data injection finished=========================") log.Info("=========================Data injection finished=========================")
s.checkCollections() s.Require().True(s.checkCollections())
log.Info(fmt.Sprintf("=========================start to search %s=========================", searchName)) log.Info(fmt.Sprintf("=========================start to search %s=========================", searchName))
s.search(searchName, Dim) s.search(searchName, Dim)
log.Info("=========================Search finished=========================") log.Info("=========================Search finished=========================")
@ -238,11 +238,13 @@ func (s *CoordSwitchSuite) setupData() {
func (s *CoordSwitchSuite) switchCoord() float64 { func (s *CoordSwitchSuite) switchCoord() float64 {
var err error var err error
c := s.Cluster c := s.Cluster
start := time.Now()
log.Info("=========================Stopping Coordinators========================")
c.RootCoord.Stop() c.RootCoord.Stop()
c.DataCoord.Stop() c.DataCoord.Stop()
c.QueryCoord.Stop() c.QueryCoord.Stop()
log.Info("=========================Coordinators stopped=========================") log.Info("=========================Coordinators stopped=========================", zap.Duration("elapsed", time.Since(start)))
start := time.Now() start = time.Now()
c.RootCoord, err = grpcrootcoord.NewServer(context.TODO(), c.GetFactory()) c.RootCoord, err = grpcrootcoord.NewServer(context.TODO(), c.GetFactory())
s.NoError(err) s.NoError(err)