Fix load collection failed after drop partition (#24680)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
This commit is contained in:
yihao.dai 2023-06-07 19:04:36 +08:00 committed by GitHub
parent a6dbcaeb7a
commit 89db828f71
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 218 additions and 18 deletions

View File

@ -21,6 +21,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/cockroachdb/errors"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/samber/lo" "github.com/samber/lo"
"go.uber.org/zap" "go.uber.org/zap"
@ -126,6 +127,16 @@ func (m *CollectionManager) Recover(broker Broker) error {
} }
for _, collection := range collections { for _, collection := range collections {
// Dropped collection should be deprecated
_, err = broker.GetCollectionSchema(context.Background(), collection.GetCollectionID())
if errors.Is(err, merr.ErrCollectionNotFound) {
log.Info("skip dropped collection during recovery", zap.Int64("collection", collection.GetCollectionID()))
m.store.ReleaseCollection(collection.GetCollectionID())
continue
}
if err != nil {
return err
}
// Collections not loaded done should be deprecated // Collections not loaded done should be deprecated
if collection.GetStatus() != querypb.LoadStatus_Loaded || collection.GetReplicaNumber() <= 0 { if collection.GetStatus() != querypb.LoadStatus_Loaded || collection.GetReplicaNumber() <= 0 {
log.Info("skip recovery and release collection", log.Info("skip recovery and release collection",
@ -142,15 +153,37 @@ func (m *CollectionManager) Recover(broker Broker) error {
} }
for collection, partitions := range partitions { for collection, partitions := range partitions {
existPartitions, err := broker.GetPartitions(context.Background(), collection)
if errors.Is(err, merr.ErrCollectionNotFound) {
log.Info("skip dropped collection during recovery", zap.Int64("collection", collection))
m.store.ReleaseCollection(collection)
continue
}
if err != nil {
return err
}
omitPartitions := make([]int64, 0)
partitions = lo.Filter(partitions, func(partition *querypb.PartitionLoadInfo, _ int) bool {
if !lo.Contains(existPartitions, partition.GetPartitionID()) {
omitPartitions = append(omitPartitions, partition.GetPartitionID())
return false
}
return true
})
if len(omitPartitions) > 0 {
log.Info("skip dropped partitions during recovery",
zap.Int64("collection", collection), zap.Int64s("partitions", omitPartitions))
m.store.ReleasePartition(collection, omitPartitions...)
}
sawLoaded := false sawLoaded := false
for _, partition := range partitions { for _, partition := range partitions {
// Partitions not loaded done should be deprecated // Partitions not loaded done should be deprecated
if partition.GetStatus() != querypb.LoadStatus_Loaded || partition.GetReplicaNumber() <= 0 { if partition.GetStatus() != querypb.LoadStatus_Loaded {
log.Info("skip recovery and release partition", log.Info("skip recovery and release partition",
zap.Int64("collectionID", collection), zap.Int64("collectionID", collection),
zap.Int64("partitionID", partition.GetPartitionID()), zap.Int64("partitionID", partition.GetPartitionID()),
zap.String("status", partition.GetStatus().String()), zap.String("status", partition.GetStatus().String()),
zap.Int32("replicaNumber", partition.GetReplicaNumber()),
) )
m.store.ReleasePartition(collection, partition.GetPartitionID()) m.store.ReleasePartition(collection, partition.GetPartitionID())
continue continue

View File

@ -17,6 +17,7 @@
package meta package meta
import ( import (
"fmt"
"sort" "sort"
"testing" "testing"
"time" "time"
@ -30,6 +31,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
. "github.com/milvus-io/milvus/internal/querycoordv2/params" . "github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/etcd"
"github.com/milvus-io/milvus/pkg/util/merr"
) )
type CollectionManagerSuite struct { type CollectionManagerSuite struct {
@ -171,6 +173,13 @@ func (suite *CollectionManagerSuite) TestGet() {
func (suite *CollectionManagerSuite) TestUpdate() { func (suite *CollectionManagerSuite) TestUpdate() {
mgr := suite.mgr mgr := suite.mgr
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(nil, nil)
for _, collection := range suite.collections {
if len(suite.partitions[collection]) > 0 {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil)
}
}
collections := mgr.GetAllCollections() collections := mgr.GetAllCollections()
partitions := mgr.GetAllPartitions() partitions := mgr.GetAllPartitions()
for _, collection := range collections { for _, collection := range collections {
@ -237,6 +246,11 @@ func (suite *CollectionManagerSuite) TestGetFieldIndex() {
func (suite *CollectionManagerSuite) TestRemove() { func (suite *CollectionManagerSuite) TestRemove() {
mgr := suite.mgr mgr := suite.mgr
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(nil, nil)
for _, collection := range suite.collections {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe()
}
// Remove collections/partitions // Remove collections/partitions
for i, collectionID := range suite.collections { for i, collectionID := range suite.collections {
if suite.loadTypes[i] == querypb.LoadType_LoadCollection { if suite.loadTypes[i] == querypb.LoadType_LoadCollection {
@ -298,16 +312,89 @@ func (suite *CollectionManagerSuite) TestRemove() {
} }
} }
func (suite *CollectionManagerSuite) TestRecover() { func (suite *CollectionManagerSuite) TestRecover_normal() {
mgr := suite.mgr mgr := suite.mgr
// recover successfully
for _, collection := range suite.collections {
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, collection).Return(nil, nil)
if len(suite.partitions[collection]) > 0 {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil)
}
}
suite.clearMemory() suite.clearMemory()
err := mgr.Recover(suite.broker) err := mgr.Recover(suite.broker)
suite.NoError(err) suite.NoError(err)
for i, collection := range suite.collections { for i, collection := range suite.collections {
exist := suite.colLoadPercent[i] == 100 exist := suite.colLoadPercent[i] == 100
suite.Equal(exist, mgr.Exist(collection)) suite.Equal(exist, mgr.Exist(collection))
if !exist {
continue
} }
for j, partitionID := range suite.partitions[collection] {
partition := mgr.GetPartition(partitionID)
exist = suite.parLoadPercent[collection][j] == 100
suite.Equal(exist, partition != nil)
}
}
}
func (suite *CollectionManagerSuite) TestRecover_with_dropped() {
mgr := suite.mgr
droppedCollection := int64(101)
droppedPartition := int64(13)
for _, collection := range suite.collections {
if collection == droppedCollection {
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, collection).Return(nil, merr.ErrCollectionNotFound)
} else {
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, collection).Return(nil, nil)
}
if len(suite.partitions[collection]) != 0 {
if collection == droppedCollection {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(nil, merr.ErrCollectionNotFound)
} else {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).
Return(lo.Filter(suite.partitions[collection], func(partition int64, _ int) bool {
return partition != droppedPartition
}), nil)
}
}
}
suite.clearMemory()
err := mgr.Recover(suite.broker)
suite.NoError(err)
for i, collection := range suite.collections {
exist := suite.colLoadPercent[i] == 100 && collection != droppedCollection
suite.Equal(exist, mgr.Exist(collection))
if !exist {
continue
}
for j, partitionID := range suite.partitions[collection] {
partition := mgr.GetPartition(partitionID)
exist = suite.parLoadPercent[collection][j] == 100 && partitionID != droppedPartition
suite.Equal(exist, partition != nil)
}
}
}
func (suite *CollectionManagerSuite) TestRecover_Failed() {
mockErr1 := fmt.Errorf("mock GetCollectionSchema err")
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(nil, mockErr1)
suite.clearMemory()
err := suite.mgr.Recover(suite.broker)
suite.Error(err)
suite.ErrorIs(err, mockErr1)
mockErr2 := fmt.Errorf("mock GetPartitions err")
suite.broker.ExpectedCalls = suite.broker.ExpectedCalls[:0]
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(nil, nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, mock.Anything).Return(nil, mockErr2)
suite.clearMemory()
err = suite.mgr.Recover(suite.broker)
suite.Error(err)
suite.ErrorIs(err, mockErr2)
} }
func (suite *CollectionManagerSuite) TestUpdateLoadPercentage() { func (suite *CollectionManagerSuite) TestUpdateLoadPercentage() {
@ -376,6 +463,13 @@ func (suite *CollectionManagerSuite) TestUpgradeRecover() {
suite.releaseAll() suite.releaseAll()
mgr := suite.mgr mgr := suite.mgr
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(nil, nil)
for _, collection := range suite.collections {
if len(suite.partitions[collection]) > 0 {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil)
}
}
// put old version of collections and partitions // put old version of collections and partitions
for i, collection := range suite.collections { for i, collection := range suite.collections {
status := querypb.LoadStatus_Loaded status := querypb.LoadStatus_Loaded

View File

@ -31,6 +31,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/merr"
@ -75,10 +76,16 @@ func (broker *CoordinatorBroker) GetCollectionSchema(ctx context.Context, collec
), ),
CollectionID: collectionID, CollectionID: collectionID,
} }
resp, err := broker.rootCoord.DescribeCollectionInternal(ctx, req) resp, err := broker.rootCoord.DescribeCollection(ctx, req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
statusErr := common.NewStatusError(resp.Status.ErrorCode, resp.Status.Reason)
if common.IsCollectionNotExistError(statusErr) {
return nil, merr.WrapErrCollectionNotFound(collectionID)
}
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
err = errors.New(resp.GetStatus().GetReason()) err = errors.New(resp.GetStatus().GetReason())
log.Error("failed to get collection schema", zap.Int64("collectionID", collectionID), zap.Error(err)) log.Error("failed to get collection schema", zap.Int64("collectionID", collectionID), zap.Error(err))
@ -96,12 +103,17 @@ func (broker *CoordinatorBroker) GetPartitions(ctx context.Context, collectionID
), ),
CollectionID: collectionID, CollectionID: collectionID,
} }
resp, err := broker.rootCoord.ShowPartitionsInternal(ctx, req) resp, err := broker.rootCoord.ShowPartitions(ctx, req)
if err != nil { if err != nil {
log.Warn("showPartition failed", zap.Int64("collectionID", collectionID), zap.Error(err)) log.Warn("showPartition failed", zap.Int64("collectionID", collectionID), zap.Error(err))
return nil, err return nil, err
} }
statusErr := common.NewStatusError(resp.Status.ErrorCode, resp.Status.Reason)
if common.IsCollectionNotExistError(statusErr) {
return nil, merr.WrapErrCollectionNotFound(collectionID)
}
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
err = errors.New(resp.GetStatus().GetReason()) err = errors.New(resp.GetStatus().GetReason())
log.Warn("showPartition failed", zap.Int64("collectionID", collectionID), zap.Error(err)) log.Warn("showPartition failed", zap.Int64("collectionID", collectionID), zap.Error(err))

View File

@ -29,15 +29,16 @@ import (
"github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/pkg/util/merr"
) )
func TestCoordinatorBroker_GetCollectionSchema(t *testing.T) { func TestCoordinatorBroker_GetCollectionSchema(t *testing.T) {
t.Run("got error on DescribeCollectionInternal", func(t *testing.T) { t.Run("got error on DescribeCollection", func(t *testing.T) {
rootCoord := mocks.NewRootCoord(t) rootCoord := mocks.NewRootCoord(t)
rootCoord.On("DescribeCollectionInternal", rootCoord.On("DescribeCollection",
mock.Anything, mock.Anything,
mock.Anything, mock.Anything,
).Return(nil, errors.New("error mock DescribeCollectionInternal")) ).Return(nil, errors.New("error mock DescribeCollection"))
ctx := context.Background() ctx := context.Background()
broker := &CoordinatorBroker{rootCoord: rootCoord} broker := &CoordinatorBroker{rootCoord: rootCoord}
_, err := broker.GetCollectionSchema(ctx, 100) _, err := broker.GetCollectionSchema(ctx, 100)
@ -46,7 +47,7 @@ func TestCoordinatorBroker_GetCollectionSchema(t *testing.T) {
t.Run("non-success code", func(t *testing.T) { t.Run("non-success code", func(t *testing.T) {
rootCoord := mocks.NewRootCoord(t) rootCoord := mocks.NewRootCoord(t)
rootCoord.On("DescribeCollectionInternal", rootCoord.On("DescribeCollection",
mock.Anything, mock.Anything,
mock.Anything, mock.Anything,
).Return(&milvuspb.DescribeCollectionResponse{ ).Return(&milvuspb.DescribeCollectionResponse{
@ -60,7 +61,7 @@ func TestCoordinatorBroker_GetCollectionSchema(t *testing.T) {
t.Run("normal case", func(t *testing.T) { t.Run("normal case", func(t *testing.T) {
rootCoord := mocks.NewRootCoord(t) rootCoord := mocks.NewRootCoord(t)
rootCoord.On("DescribeCollectionInternal", rootCoord.On("DescribeCollection",
mock.Anything, mock.Anything,
mock.Anything, mock.Anything,
).Return(&milvuspb.DescribeCollectionResponse{ ).Return(&milvuspb.DescribeCollectionResponse{
@ -114,3 +115,37 @@ func TestCoordinatorBroker_GetRecoveryInfo(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
}) })
} }
func TestCoordinatorBroker_GetPartitions(t *testing.T) {
collection := int64(100)
partitions := []int64{10, 11, 12}
t.Run("normal case", func(t *testing.T) {
rc := mocks.NewRootCoord(t)
rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{
Status: &commonpb.Status{},
PartitionIDs: partitions,
}, nil)
ctx := context.Background()
broker := &CoordinatorBroker{rootCoord: rc}
retPartitions, err := broker.GetPartitions(ctx, collection)
assert.NoError(t, err)
assert.ElementsMatch(t, partitions, retPartitions)
})
t.Run("collection not exist", func(t *testing.T) {
rc := mocks.NewRootCoord(t)
rc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_CollectionNotExists,
},
}, nil)
ctx := context.Background()
broker := &CoordinatorBroker{rootCoord: rc}
_, err := broker.GetPartitions(ctx, collection)
assert.ErrorIs(t, err, merr.ErrCollectionNotFound)
})
}

View File

@ -101,8 +101,7 @@ func (suite *ServerSuite) SetupSuite() {
func (suite *ServerSuite) SetupTest() { func (suite *ServerSuite) SetupTest() {
var err error var err error
suite.server, err = suite.newQueryCoord()
suite.server, err = newQueryCoord()
suite.Require().NoError(err) suite.Require().NoError(err)
suite.hackServer() suite.hackServer()
err = suite.server.Start() err = suite.server.Start()
@ -139,7 +138,7 @@ func (suite *ServerSuite) TestRecover() {
err := suite.server.Stop() err := suite.server.Stop()
suite.NoError(err) suite.NoError(err)
suite.server, err = newQueryCoord() suite.server, err = suite.newQueryCoord()
suite.NoError(err) suite.NoError(err)
suite.hackServer() suite.hackServer()
err = suite.server.Start() err = suite.server.Start()
@ -154,7 +153,7 @@ func (suite *ServerSuite) TestRecoverFailed() {
err := suite.server.Stop() err := suite.server.Stop()
suite.NoError(err) suite.NoError(err)
suite.server, err = newQueryCoord() suite.server, err = suite.newQueryCoord()
suite.NoError(err) suite.NoError(err)
broker := meta.NewMockBroker(suite.T()) broker := meta.NewMockBroker(suite.T())
@ -273,7 +272,7 @@ func (suite *ServerSuite) TestDisableActiveStandby() {
err := suite.server.Stop() err := suite.server.Stop()
suite.NoError(err) suite.NoError(err)
suite.server, err = newQueryCoord() suite.server, err = suite.newQueryCoord()
suite.NoError(err) suite.NoError(err)
suite.Equal(commonpb.StateCode_Initializing, suite.server.State()) suite.Equal(commonpb.StateCode_Initializing, suite.server.State())
suite.hackServer() suite.hackServer()
@ -294,7 +293,7 @@ func (suite *ServerSuite) TestEnableActiveStandby() {
err := suite.server.Stop() err := suite.server.Stop()
suite.NoError(err) suite.NoError(err)
suite.server, err = newQueryCoord() suite.server, err = suite.newQueryCoord()
suite.NoError(err) suite.NoError(err)
mockRootCoord := coordMocks.NewRootCoord(suite.T()) mockRootCoord := coordMocks.NewRootCoord(suite.T())
mockDataCoord := coordMocks.NewDataCoord(suite.T()) mockDataCoord := coordMocks.NewDataCoord(suite.T())
@ -310,7 +309,7 @@ func (suite *ServerSuite) TestEnableActiveStandby() {
), ),
CollectionID: collection, CollectionID: collection,
} }
mockRootCoord.EXPECT().ShowPartitionsInternal(mock.Anything, req).Return(&milvuspb.ShowPartitionsResponse{ mockRootCoord.EXPECT().ShowPartitions(mock.Anything, req).Return(&milvuspb.ShowPartitionsResponse{
Status: merr.Status(nil), Status: merr.Status(nil),
PartitionIDs: suite.partitions[collection], PartitionIDs: suite.partitions[collection],
}, nil).Maybe() }, nil).Maybe()
@ -530,7 +529,33 @@ func (suite *ServerSuite) hackServer() {
log.Debug("server hacked") log.Debug("server hacked")
} }
func newQueryCoord() (*Server, error) { func (suite *ServerSuite) hackBroker(server *Server) {
mockRootCoord := coordMocks.NewRootCoord(suite.T())
mockDataCoord := coordMocks.NewDataCoord(suite.T())
for _, collection := range suite.collections {
mockRootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
Status: merr.Status(nil),
Schema: &schemapb.CollectionSchema{},
}, nil).Maybe()
req := &milvuspb.ShowPartitionsRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions),
),
CollectionID: collection,
}
mockRootCoord.EXPECT().ShowPartitions(mock.Anything, req).Return(&milvuspb.ShowPartitionsResponse{
Status: merr.Status(nil),
PartitionIDs: suite.partitions[collection],
}, nil).Maybe()
}
err := server.SetRootCoord(mockRootCoord)
suite.NoError(err)
err = server.SetDataCoord(mockDataCoord)
suite.NoError(err)
}
func (suite *ServerSuite) newQueryCoord() (*Server, error) {
server, err := NewQueryCoord(context.Background()) server, err := NewQueryCoord(context.Background())
if err != nil { if err != nil {
return nil, err return nil, err
@ -549,6 +574,7 @@ func newQueryCoord() (*Server, error) {
} }
server.SetEtcdClient(etcdCli) server.SetEtcdClient(etcdCli)
server.SetQueryNodeCreator(session.DefaultQueryNodeCreator) server.SetQueryNodeCreator(session.DefaultQueryNodeCreator)
suite.hackBroker(server)
err = server.Init() err = server.Init()
return server, err return server, err
} }