Dynamic load/release partitions (#22655)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
yihao.dai 2023-03-20 14:55:57 +08:00 committed by GitHub
parent 7c633e9b9d
commit 1f718118e9
78 changed files with 3184 additions and 1574 deletions

View File

@ -272,6 +272,25 @@ func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart
return ret.(*commonpb.Status), err
}
// SyncNewCreatedPartition notifies QueryCoord to sync the newly created partition if the collection is loaded.
func (c *Client) SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error) {
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.sess.ServerID)),
)
ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) {
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.SyncNewCreatedPartition(ctx, req)
})
if err != nil || ret == nil {
return nil, err
}
return ret.(*commonpb.Status), err
}
// GetPartitionStates gets the states of the specified partition.
func (c *Client) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) {
req = typeutil.Clone(req)
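A minimal sketch of how a caller (for example, the proxy or RootCoord after creating a partition) might invoke the new client method above; the local syncer interface, helper name, and error handling are illustrative assumptions, not code from this commit.

import (
	"context"
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/commonpb"
	"github.com/milvus-io/milvus/internal/proto/querypb"
)

// syncer captures only the method added in this commit.
type syncer interface {
	SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error)
}

// syncNewPartition asks QueryCoord to pick up a partition created after its collection was loaded.
func syncNewPartition(ctx context.Context, qc syncer, collectionID, partitionID int64) error {
	status, err := qc.SyncNewCreatedPartition(ctx, &querypb.SyncNewCreatedPartitionRequest{
		Base:         &commonpb.MsgBase{}, // the client clones the request and fills the MsgBase itself
		CollectionID: collectionID,
		PartitionID:  partitionID,
	})
	if err != nil {
		return err
	}
	if status.GetErrorCode() != commonpb.ErrorCode_Success {
		return fmt.Errorf("SyncNewCreatedPartition failed: %s", status.GetReason())
	}
	return nil
}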

View File

@ -114,6 +114,9 @@ func Test_NewClient(t *testing.T) {
r7, err := client.ReleasePartitions(ctx, nil)
retCheck(retNotNil, r7, err)
r7, err = client.SyncNewCreatedPartition(ctx, nil)
retCheck(retNotNil, r7, err)
r8, err := client.ShowCollections(ctx, nil)
retCheck(retNotNil, r8, err)

View File

@ -329,6 +329,11 @@ func (s *Server) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart
return s.queryCoord.ReleasePartitions(ctx, req)
}
// SyncNewCreatedPartition notifies QueryCoord to sync the newly created partition if the collection is loaded.
func (s *Server) SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error) {
return s.queryCoord.SyncNewCreatedPartition(ctx, req)
}
// GetSegmentInfo gets the information of the specified segment from QueryCoord.
func (s *Server) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
return s.queryCoord.GetSegmentInfo(ctx, req)

View File

@ -213,6 +213,24 @@ func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl
return ret.(*commonpb.Status), err
}
// LoadPartitions updates the partition meta info in QueryNode.
func (c *Client) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryNodeClient) (any, error) {
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.LoadPartitions(ctx, req)
})
if err != nil || ret == nil {
return nil, err
}
return ret.(*commonpb.Status), err
}
// ReleasePartitions releases the data of the specified partitions in QueryNode.
func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
req = typeutil.Clone(req)
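A minimal sketch of how QueryCoord's cluster layer might use the new QueryNode-side LoadPartitions RPC to push partition meta to a node; the partitionLoader interface and helper below are illustrative assumptions, not code from this commit.

import (
	"context"

	"github.com/cockroachdb/errors"

	"github.com/milvus-io/milvus-proto/go-api/commonpb"
	"github.com/milvus-io/milvus/internal/proto/querypb"
)

// partitionLoader captures only the QueryNode method added in this commit.
type partitionLoader interface {
	LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error)
}

// pushPartitionMeta tells one QueryNode which partitions of a collection are loaded.
func pushPartitionMeta(ctx context.Context, node partitionLoader, collectionID int64, partitionIDs []int64) error {
	status, err := node.LoadPartitions(ctx, &querypb.LoadPartitionsRequest{
		Base:         &commonpb.MsgBase{},
		CollectionID: collectionID,
		PartitionIDs: partitionIDs,
	})
	if err != nil {
		return err
	}
	if status.GetErrorCode() != commonpb.ErrorCode_Success {
		return errors.New(status.GetReason())
	}
	return nil
}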

View File

@ -78,6 +78,9 @@ func Test_NewClient(t *testing.T) {
r8, err := client.ReleaseCollection(ctx, nil)
retCheck(retNotNil, r8, err)
r8, err = client.LoadPartitions(ctx, nil)
retCheck(retNotNil, r8, err)
r9, err := client.ReleasePartitions(ctx, nil)
retCheck(retNotNil, r9, err)

View File

@ -276,6 +276,11 @@ func (s *Server) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl
return s.querynode.ReleaseCollection(ctx, req)
}
// LoadPartitions updates the partition meta info in QueryNode.
func (s *Server) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
return s.querynode.LoadPartitions(ctx, req)
}
// ReleasePartitions releases the data of the specified partitions in QueryNode.
func (s *Server) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
// ignore ctx

View File

@ -94,6 +94,10 @@ func (m *MockQueryNode) ReleaseCollection(ctx context.Context, req *querypb.Rele
return m.status, m.err
}
func (m *MockQueryNode) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
return m.status, m.err
}
func (m *MockQueryNode) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
return m.status, m.err
}
@ -263,6 +267,13 @@ func Test_NewServer(t *testing.T) {
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
})
t.Run("LoadPartitions", func(t *testing.T) {
req := &querypb.LoadPartitionsRequest{}
resp, err := server.LoadPartitions(ctx, req)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
})
t.Run("ReleasePartitions", func(t *testing.T) { t.Run("ReleasePartitions", func(t *testing.T) {
req := &querypb.ReleasePartitionsRequest{} req := &querypb.ReleasePartitionsRequest{}
resp, err := server.ReleasePartitions(ctx, req) resp, err := server.ReleasePartitions(ctx, req)

View File

@ -150,13 +150,13 @@ type IndexCoordCatalog interface {
}
type QueryCoordCatalog interface {
SaveCollection(info *querypb.CollectionLoadInfo) error SaveCollection(collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo) error
SavePartition(info ...*querypb.PartitionLoadInfo) error
SaveReplica(replica *querypb.Replica) error
GetCollections() ([]*querypb.CollectionLoadInfo, error)
GetPartitions() (map[int64][]*querypb.PartitionLoadInfo, error)
GetReplicas() ([]*querypb.Replica, error)
ReleaseCollection(id int64) error ReleaseCollection(collection int64) error
ReleasePartition(collection int64, partitions ...int64) error
ReleaseReplicas(collectionID int64) error
ReleaseReplica(collection, replica int64) error
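A minimal sketch of what the widened SaveCollection signature enables: persisting a collection load record together with its partition load records in one catalog call. It is written as if it lived next to the interface above (same package and imports); the helper name and defaults are assumptions for illustration.

// saveLoadingCollection stores the collection and its partitions as one unit.
func saveLoadingCollection(catalog QueryCoordCatalog, collectionID int64, partitionIDs []int64) error {
	partitions := make([]*querypb.PartitionLoadInfo, 0, len(partitionIDs))
	for _, partID := range partitionIDs {
		partitions = append(partitions, &querypb.PartitionLoadInfo{
			CollectionID: collectionID,
			PartitionID:  partID,
			Status:       querypb.LoadStatus_Loading,
		})
	}
	return catalog.SaveCollection(&querypb.CollectionLoadInfo{
		CollectionID:  collectionID,
		ReplicaNumber: 1, // assumed default
		Status:        querypb.LoadStatus_Loading,
		LoadType:      querypb.LoadType_LoadCollection,
	}, partitions...)
}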

View File

@ -31,12 +31,12 @@ var (
Help: "number of collections", Help: "number of collections",
}, []string{}) }, []string{})
QueryCoordNumEntities = prometheus.NewGaugeVec( QueryCoordNumPartitions = prometheus.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Namespace: milvusNamespace, Namespace: milvusNamespace,
Subsystem: typeutil.QueryCoordRole, Subsystem: typeutil.QueryCoordRole,
Name: "entity_num", Name: "partition_num",
Help: "number of entities", Help: "number of partitions",
}, []string{}) }, []string{})
QueryCoordLoadCount = prometheus.NewCounterVec( QueryCoordLoadCount = prometheus.NewCounterVec(
@ -97,7 +97,7 @@ var (
// RegisterQueryCoord registers QueryCoord metrics
func RegisterQueryCoord(registry *prometheus.Registry) {
registry.MustRegister(QueryCoordNumCollections)
registry.MustRegister(QueryCoordNumEntities) registry.MustRegister(QueryCoordNumPartitions)
registry.MustRegister(QueryCoordLoadCount)
registry.MustRegister(QueryCoordReleaseCount)
registry.MustRegister(QueryCoordLoadLatency)

View File

@ -23,6 +23,7 @@ service QueryCoord {
rpc ReleasePartitions(ReleasePartitionsRequest) returns (common.Status) {}
rpc LoadCollection(LoadCollectionRequest) returns (common.Status) {}
rpc ReleaseCollection(ReleaseCollectionRequest) returns (common.Status) {}
rpc SyncNewCreatedPartition(SyncNewCreatedPartitionRequest) returns (common.Status) {}
rpc GetPartitionStates(GetPartitionStatesRequest) returns (GetPartitionStatesResponse) {}
rpc GetSegmentInfo(GetSegmentInfoRequest) returns (GetSegmentInfoResponse) {}
@ -55,6 +56,7 @@ service QueryNode {
rpc UnsubDmChannel(UnsubDmChannelRequest) returns (common.Status) {}
rpc LoadSegments(LoadSegmentsRequest) returns (common.Status) {}
rpc ReleaseCollection(ReleaseCollectionRequest) returns (common.Status) {}
rpc LoadPartitions(LoadPartitionsRequest) returns (common.Status) {}
rpc ReleasePartitions(ReleasePartitionsRequest) returns (common.Status) {}
rpc ReleaseSegments(ReleaseSegmentsRequest) returns (common.Status) {}
rpc GetSegmentInfo(GetSegmentInfoRequest) returns (GetSegmentInfoResponse) {}
@ -189,6 +191,12 @@ message ShardLeadersList { // All leaders of all replicas of one shard
repeated string node_addrs = 3;
}
message SyncNewCreatedPartitionRequest {
common.MsgBase base = 1;
int64 collectionID = 2;
int64 partitionID = 3;
}
//-----------------query node grpc request and response proto----------------
message LoadMetaInfo {
LoadType load_type = 1;
@ -482,18 +490,19 @@ enum LoadStatus {
message CollectionLoadInfo {
int64 collectionID = 1;
repeated int64 released_partitions = 2; // Deprecated: No longer used; kept for compatibility.
int32 replica_number = 3;
LoadStatus status = 4;
map<int64, int64> field_indexID = 5;
LoadType load_type = 6;
}
message PartitionLoadInfo {
int64 collectionID = 1;
int64 partitionID = 2;
int32 replica_number = 3; // Deprecated: No longer used; kept for compatibility.
LoadStatus status = 4;
map<int64, int64> field_indexID = 5; // Deprecated: No longer used; kept for compatibility.
}
message Replica {
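With load_type now persisted in CollectionLoadInfo, release logic can distinguish collections loaded as a whole from collections loaded partition by partition. A short Go sketch of that decision using the generated querypb types, mirroring the condition in the new ReleasePartitionJob further below; the helper name is an assumption for illustration.

// shouldRemoveWholeCollection reports whether releasing these partitions should drop the
// whole collection record: only when every loaded partition is being released and the
// collection was loaded via LoadPartitions in the first place.
func shouldRemoveWholeCollection(info *querypb.CollectionLoadInfo, loadedPartitions, toRelease int) bool {
	return toRelease == loadedPartitions && info.GetLoadType() == querypb.LoadType_LoadPartition
}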

File diff suppressed because it is too large

View File

@ -1980,7 +1980,7 @@ func TestProxy(t *testing.T) {
Type: milvuspb.ShowType_InMemory,
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode) assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
// default partition
assert.Equal(t, 0, len(resp.PartitionNames))

View File

@ -85,6 +85,10 @@ func (m *QueryNodeMock) ReleaseCollection(ctx context.Context, req *querypb.Rele
return nil, nil
}
func (m *QueryNodeMock) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
return nil, nil

View File

@ -868,32 +868,6 @@ func (dpt *dropPartitionTask) PreExecute(ctx context.Context) error {
return err
}
collID, err := globalMetaCache.GetCollectionID(ctx, dpt.GetCollectionName())
if err != nil {
return err
}
partID, err := globalMetaCache.GetPartitionID(ctx, dpt.GetCollectionName(), dpt.GetPartitionName())
if err != nil {
if err.Error() == ErrPartitionNotExist(dpt.GetPartitionName()).Error() {
return nil
}
return err
}
collLoaded, err := isCollectionLoaded(ctx, dpt.queryCoord, collID)
if err != nil {
return err
}
if collLoaded {
loaded, err := isPartitionLoaded(ctx, dpt.queryCoord, collID, []int64{partID})
if err != nil {
return err
}
if loaded {
return errors.New("partition cannot be dropped, partition is loaded, please release it first")
}
}
return nil
}
@ -1587,6 +1561,9 @@ func (lpt *loadPartitionsTask) Execute(ctx context.Context) error {
}
partitionIDs = append(partitionIDs, partitionID)
}
if len(partitionIDs) == 0 {
return errors.New("failed to load partition, due to no partition specified")
}
request := &querypb.LoadPartitionsRequest{
Base: commonpbutil.UpdateMsgBase(
lpt.Base,

View File

@ -1134,20 +1134,6 @@ func TestDropPartitionTask(t *testing.T) {
err = task.PreExecute(ctx)
assert.NotNil(t, err)
t.Run("get collectionID error", func(t *testing.T) {
mockCache := newMockCache()
mockCache.setGetPartitionIDFunc(func(ctx context.Context, collectionName string, partitionName string) (typeutil.UniqueID, error) {
return 1, nil
})
mockCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
return 0, errors.New("error")
})
globalMetaCache = mockCache
task.PartitionName = "partition1"
err = task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("partition not exist", func(t *testing.T) { t.Run("partition not exist", func(t *testing.T) {
task.PartitionName = "partition2" task.PartitionName = "partition2"
@ -1162,21 +1148,6 @@ func TestDropPartitionTask(t *testing.T) {
err = task.PreExecute(ctx)
assert.NoError(t, err)
})
t.Run("get partition error", func(t *testing.T) {
task.PartitionName = "partition3"
mockCache := newMockCache()
mockCache.setGetPartitionIDFunc(func(ctx context.Context, collectionName string, partitionName string) (typeutil.UniqueID, error) {
return 0, errors.New("error")
})
mockCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
return 1, nil
})
globalMetaCache = mockCache
err = task.PreExecute(ctx)
assert.Error(t, err)
})
}
func TestHasPartitionTask(t *testing.T) {

View File

@ -19,13 +19,14 @@ package balance
import (
"sort"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/task"
"github.com/samber/lo"
"go.uber.org/zap"
)
type RowCountBasedBalancer struct {

View File

@ -69,6 +69,8 @@ func (suite *RowCountBasedBalancerTestSuite) SetupTest() {
distManager := meta.NewDistributionManager()
suite.mockScheduler = task.NewMockScheduler(suite.T())
suite.balancer = NewRowCountBasedBalancer(suite.mockScheduler, nodeManager, distManager, testMeta, testTarget)
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil).Maybe()
}
func (suite *RowCountBasedBalancerTestSuite) TearDownTest() {
@ -257,6 +259,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() {
balancer.targetMgr.UpdateCollectionCurrentTarget(1, 1)
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
collection.LoadType = querypb.LoadType_LoadCollection
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...)))
for node, s := range c.distributions {
@ -359,6 +362,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOutboundNodes() {
balancer.targetMgr.UpdateCollectionCurrentTarget(1, 1)
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded
collection.LoadType = querypb.LoadType_LoadCollection
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...)))
for node, s := range c.distributions {
@ -415,6 +419,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnLoadingCollection() {
collection := utils.CreateTestCollection(1, 1)
collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loading
collection.LoadType = querypb.LoadType_LoadCollection
balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, c.nodes))
for node, s := range c.distributions {

View File

@ -74,6 +74,8 @@ func (suite *ChannelCheckerTestSuite) SetupTest() {
balancer := suite.createMockBalancer()
suite.checker = NewChannelChecker(suite.meta, distManager, targetManager, balancer)
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil).Maybe()
}
func (suite *ChannelCheckerTestSuite) TearDownTest() {

View File

@ -74,6 +74,8 @@ func (suite *SegmentCheckerTestSuite) SetupTest() {
balancer := suite.createMockBalancer()
suite.checker = NewSegmentChecker(suite.meta, distManager, targetManager, balancer)
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil).Maybe()
}
func (suite *SegmentCheckerTestSuite) TearDownTest() {

View File

@ -26,4 +26,5 @@ var (
ErrCollectionLoaded = errors.New("CollectionLoaded")
ErrLoadParameterMismatched = errors.New("LoadParameterMismatched")
ErrNoEnoughNode = errors.New("NoEnoughNode")
ErrPartitionNotInTarget = errors.New("PartitionNotInLoadingTarget")
)
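ErrPartitionNotInTarget acts as a sentinel: the new sync job returns it when the collection is not loaded or the partition is already known, and callers are expected to treat that as a no-op rather than a failure. A minimal sketch of that check, assuming callers receive the sentinel unwrapped or wrapped in a way errors.Is can see through.

// isIgnorableSyncError reports whether a sync-new-partition result needs no further action.
func isIgnorableSyncError(err error) bool {
	return err == nil || errors.Is(err, ErrPartitionNotInTarget)
}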

View File

@ -18,20 +18,6 @@ package job
import (
"context"
"fmt"
"time"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/observers"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
// Job is request of loading/releasing collection/partitions,
@ -106,439 +92,3 @@ func (job *BaseJob) PreExecute() error {
}
func (job *BaseJob) PostExecute() {}
type LoadCollectionJob struct {
*BaseJob
req *querypb.LoadCollectionRequest
dist *meta.DistributionManager
meta *meta.Meta
targetMgr *meta.TargetManager
broker meta.Broker
nodeMgr *session.NodeManager
}
func NewLoadCollectionJob(
ctx context.Context,
req *querypb.LoadCollectionRequest,
dist *meta.DistributionManager,
meta *meta.Meta,
targetMgr *meta.TargetManager,
broker meta.Broker,
nodeMgr *session.NodeManager,
) *LoadCollectionJob {
return &LoadCollectionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
dist: dist,
meta: meta,
targetMgr: targetMgr,
broker: broker,
nodeMgr: nodeMgr,
}
}
func (job *LoadCollectionJob) PreExecute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
)
if req.GetReplicaNumber() <= 0 {
log.Info("request doesn't indicate the number of replicas, set it to 1",
zap.Int32("replicaNumber", req.GetReplicaNumber()))
req.ReplicaNumber = 1
}
if job.meta.Exist(req.GetCollectionID()) {
old := job.meta.GetCollection(req.GetCollectionID())
if old == nil {
msg := "load the partition after load collection is not supported"
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
} else if old.GetReplicaNumber() != req.GetReplicaNumber() {
msg := fmt.Sprintf("collection with different replica number %d existed, release this collection first before changing its replica number",
job.meta.GetReplicaNumber(req.GetCollectionID()),
)
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
} else if !typeutil.MapEqual(old.GetFieldIndexID(), req.GetFieldIndexID()) {
msg := fmt.Sprintf("collection with different index %v existed, release this collection first before changing its index",
old.GetFieldIndexID())
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
}
return ErrCollectionLoaded
}
return nil
}
func (job *LoadCollectionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
)
meta.GlobalFailedLoadCache.Remove(req.GetCollectionID())
// Clear stale replicas
err := job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
if err != nil {
log.Warn("failed to clear stale replicas", zap.Error(err))
return err
}
// Create replicas
replicas, err := utils.SpawnReplicasWithRG(job.meta,
req.GetCollectionID(),
req.GetResourceGroups(),
req.GetReplicaNumber(),
)
if err != nil {
msg := "failed to spawn replica for collection"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
for _, replica := range replicas {
log.Info("replica created",
zap.Int64("replicaID", replica.GetID()),
zap.Int64s("nodes", replica.GetNodes()),
zap.String("resourceGroup", replica.GetResourceGroup()))
}
// Fetch channels and segments from DataCoord
partitionIDs, err := job.broker.GetPartitions(job.ctx, req.GetCollectionID())
if err != nil {
msg := "failed to get partitions from RootCoord"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
// It's safe here to call UpdateCollectionNextTargetWithPartitions, as the collection not existing
err = job.targetMgr.UpdateCollectionNextTargetWithPartitions(req.GetCollectionID(), partitionIDs...)
if err != nil {
msg := "failed to update next targets for collection"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
err = job.meta.CollectionManager.PutCollection(&meta.Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: req.GetCollectionID(),
ReplicaNumber: req.GetReplicaNumber(),
Status: querypb.LoadStatus_Loading,
FieldIndexID: req.GetFieldIndexID(),
},
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
})
if err != nil {
msg := "failed to store collection"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
metrics.QueryCoordNumCollections.WithLabelValues().Inc()
return nil
}
func (job *LoadCollectionJob) PostExecute() {
if job.Error() != nil && !job.meta.Exist(job.CollectionID()) {
job.meta.ReplicaManager.RemoveCollection(job.CollectionID())
job.targetMgr.RemoveCollection(job.req.GetCollectionID())
}
}
type ReleaseCollectionJob struct {
*BaseJob
req *querypb.ReleaseCollectionRequest
dist *meta.DistributionManager
meta *meta.Meta
targetMgr *meta.TargetManager
targetObserver *observers.TargetObserver
}
func NewReleaseCollectionJob(ctx context.Context,
req *querypb.ReleaseCollectionRequest,
dist *meta.DistributionManager,
meta *meta.Meta,
targetMgr *meta.TargetManager,
targetObserver *observers.TargetObserver,
) *ReleaseCollectionJob {
return &ReleaseCollectionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
dist: dist,
meta: meta,
targetMgr: targetMgr,
targetObserver: targetObserver,
}
}
func (job *ReleaseCollectionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
)
if !job.meta.CollectionManager.Exist(req.GetCollectionID()) {
log.Info("release collection end, the collection has not been loaded into QueryNode")
return nil
}
err := job.meta.CollectionManager.RemoveCollection(req.GetCollectionID())
if err != nil {
msg := "failed to remove collection"
log.Warn(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
if err != nil {
msg := "failed to remove replicas"
log.Warn(msg, zap.Error(err))
}
job.targetMgr.RemoveCollection(req.GetCollectionID())
job.targetObserver.ReleaseCollection(req.GetCollectionID())
waitCollectionReleased(job.dist, req.GetCollectionID())
metrics.QueryCoordNumCollections.WithLabelValues().Dec()
return nil
}
type LoadPartitionJob struct {
*BaseJob
req *querypb.LoadPartitionsRequest
dist *meta.DistributionManager
meta *meta.Meta
targetMgr *meta.TargetManager
broker meta.Broker
nodeMgr *session.NodeManager
}
func NewLoadPartitionJob(
ctx context.Context,
req *querypb.LoadPartitionsRequest,
dist *meta.DistributionManager,
meta *meta.Meta,
targetMgr *meta.TargetManager,
broker meta.Broker,
nodeMgr *session.NodeManager,
) *LoadPartitionJob {
return &LoadPartitionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
dist: dist,
meta: meta,
targetMgr: targetMgr,
broker: broker,
nodeMgr: nodeMgr,
}
}
func (job *LoadPartitionJob) PreExecute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
)
if req.GetReplicaNumber() <= 0 {
log.Info("request doesn't indicate the number of replicas, set it to 1",
zap.Int32("replicaNumber", req.GetReplicaNumber()))
req.ReplicaNumber = 1
}
if job.meta.Exist(req.GetCollectionID()) {
old := job.meta.GetCollection(req.GetCollectionID())
if old != nil {
msg := "load the partition after load collection is not supported"
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
} else if job.meta.GetReplicaNumber(req.GetCollectionID()) != req.GetReplicaNumber() {
msg := "collection with different replica number existed, release this collection first before changing its replica number"
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
} else if !typeutil.MapEqual(job.meta.GetFieldIndex(req.GetCollectionID()), req.GetFieldIndexID()) {
msg := fmt.Sprintf("collection with different index %v existed, release this collection first before changing its index",
job.meta.GetFieldIndex(req.GetCollectionID()))
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
}
// Check whether one of the given partitions not loaded
for _, partitionID := range req.GetPartitionIDs() {
partition := job.meta.GetPartition(partitionID)
if partition == nil {
msg := fmt.Sprintf("some partitions %v of collection %v has been loaded into QueryNode, please release partitions firstly",
req.GetPartitionIDs(),
req.GetCollectionID())
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
}
}
return ErrCollectionLoaded
}
return nil
}
func (job *LoadPartitionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.Int64s("partitionIDs", req.GetPartitionIDs()),
)
meta.GlobalFailedLoadCache.Remove(req.GetCollectionID())
// Clear stale replicas
err := job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
if err != nil {
log.Warn("failed to clear stale replicas", zap.Error(err))
return err
}
// Create replicas
replicas, err := utils.SpawnReplicasWithRG(job.meta,
req.GetCollectionID(),
req.GetResourceGroups(),
req.GetReplicaNumber(),
)
if err != nil {
msg := "failed to spawn replica for collection"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
for _, replica := range replicas {
log.Info("replica created",
zap.Int64("replicaID", replica.GetID()),
zap.Int64s("nodes", replica.GetNodes()),
zap.String("resourceGroup", replica.GetResourceGroup()))
}
// It's safe here to call UpdateCollectionNextTargetWithPartitions, as the collection not existing
err = job.targetMgr.UpdateCollectionNextTargetWithPartitions(req.GetCollectionID(), req.GetPartitionIDs()...)
if err != nil {
msg := "failed to update next targets for collection"
log.Error(msg,
zap.Int64s("partitionIDs", req.GetPartitionIDs()),
zap.Error(err))
return utils.WrapError(msg, err)
}
partitions := lo.Map(req.GetPartitionIDs(), func(partition int64, _ int) *meta.Partition {
return &meta.Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: req.GetCollectionID(),
PartitionID: partition,
ReplicaNumber: req.GetReplicaNumber(),
Status: querypb.LoadStatus_Loading,
FieldIndexID: req.GetFieldIndexID(),
},
CreatedAt: time.Now(),
}
})
err = job.meta.CollectionManager.PutPartition(partitions...)
if err != nil {
msg := "failed to store partitions"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
metrics.QueryCoordNumCollections.WithLabelValues().Inc()
return nil
}
func (job *LoadPartitionJob) PostExecute() {
if job.Error() != nil && !job.meta.Exist(job.CollectionID()) {
job.meta.ReplicaManager.RemoveCollection(job.CollectionID())
job.targetMgr.RemoveCollection(job.req.GetCollectionID())
}
}
type ReleasePartitionJob struct {
*BaseJob
req *querypb.ReleasePartitionsRequest
dist *meta.DistributionManager
meta *meta.Meta
targetMgr *meta.TargetManager
targetObserver *observers.TargetObserver
}
func NewReleasePartitionJob(ctx context.Context,
req *querypb.ReleasePartitionsRequest,
dist *meta.DistributionManager,
meta *meta.Meta,
targetMgr *meta.TargetManager,
targetObserver *observers.TargetObserver,
) *ReleasePartitionJob {
return &ReleasePartitionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
dist: dist,
meta: meta,
targetMgr: targetMgr,
targetObserver: targetObserver,
}
}
func (job *ReleasePartitionJob) PreExecute() error {
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", job.req.GetCollectionID()),
)
if job.meta.CollectionManager.GetLoadType(job.req.GetCollectionID()) == querypb.LoadType_LoadCollection {
msg := "releasing some partitions after load collection is not supported"
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
}
return nil
}
func (job *ReleasePartitionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
)
if !job.meta.CollectionManager.Exist(req.GetCollectionID()) {
log.Info("release collection end, the collection has not been loaded into QueryNode")
return nil
}
loadedPartitions := job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID())
partitionIDs := typeutil.NewUniqueSet(req.GetPartitionIDs()...)
toRelease := make([]int64, 0)
for _, partition := range loadedPartitions {
if partitionIDs.Contain(partition.GetPartitionID()) {
toRelease = append(toRelease, partition.GetPartitionID())
}
}
if len(toRelease) == len(loadedPartitions) { // All partitions are released, clear all
log.Info("release partitions covers all partitions, will remove the whole collection")
err := job.meta.CollectionManager.RemoveCollection(req.GetCollectionID())
if err != nil {
msg := "failed to release partitions from store"
log.Warn(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
if err != nil {
log.Warn("failed to remove replicas", zap.Error(err))
}
job.targetMgr.RemoveCollection(req.GetCollectionID())
job.targetObserver.ReleaseCollection(req.GetCollectionID())
waitCollectionReleased(job.dist, req.GetCollectionID())
} else {
err := job.meta.CollectionManager.RemovePartition(toRelease...)
if err != nil {
msg := "failed to release partitions from store"
log.Warn(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
job.targetMgr.RemovePartition(req.GetCollectionID(), toRelease...)
waitCollectionReleased(job.dist, req.GetCollectionID(), toRelease...)
}
metrics.QueryCoordNumCollections.WithLabelValues().Dec()
return nil
}

View File

@ -0,0 +1,397 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package job
import (
"context"
"fmt"
"time"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/observers"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
type LoadCollectionJob struct {
*BaseJob
req *querypb.LoadCollectionRequest
undo *UndoList
dist *meta.DistributionManager
meta *meta.Meta
cluster session.Cluster
targetMgr *meta.TargetManager
targetObserver *observers.TargetObserver
broker meta.Broker
nodeMgr *session.NodeManager
}
func NewLoadCollectionJob(
ctx context.Context,
req *querypb.LoadCollectionRequest,
dist *meta.DistributionManager,
meta *meta.Meta,
cluster session.Cluster,
targetMgr *meta.TargetManager,
targetObserver *observers.TargetObserver,
broker meta.Broker,
nodeMgr *session.NodeManager,
) *LoadCollectionJob {
return &LoadCollectionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
undo: NewUndoList(ctx, meta, cluster, targetMgr, targetObserver),
dist: dist,
meta: meta,
cluster: cluster,
targetMgr: targetMgr,
targetObserver: targetObserver,
broker: broker,
nodeMgr: nodeMgr,
}
}
func (job *LoadCollectionJob) PreExecute() error {
req := job.req
log := log.Ctx(job.ctx).With(zap.Int64("collectionID", req.GetCollectionID()))
if req.GetReplicaNumber() <= 0 {
log.Info("request doesn't indicate the number of replicas, set it to 1",
zap.Int32("replicaNumber", req.GetReplicaNumber()))
req.ReplicaNumber = 1
}
collection := job.meta.GetCollection(req.GetCollectionID())
if collection == nil {
return nil
}
if collection.GetReplicaNumber() != req.GetReplicaNumber() {
msg := fmt.Sprintf("collection with different replica number %d existed, release this collection first before changing its replica number",
job.meta.GetReplicaNumber(req.GetCollectionID()),
)
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
} else if !typeutil.MapEqual(collection.GetFieldIndexID(), req.GetFieldIndexID()) {
msg := fmt.Sprintf("collection with different index %v existed, release this collection first before changing its index",
collection.GetFieldIndexID())
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
}
return nil
}
func (job *LoadCollectionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(zap.Int64("collectionID", req.GetCollectionID()))
meta.GlobalFailedLoadCache.Remove(req.GetCollectionID())
// 1. Fetch target partitions
partitionIDs, err := job.broker.GetPartitions(job.ctx, req.GetCollectionID())
if err != nil {
msg := "failed to get partitions from RootCoord"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
loadedPartitionIDs := lo.Map(job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID()),
func(partition *meta.Partition, _ int) int64 {
return partition.GetPartitionID()
})
lackPartitionIDs := lo.FilterMap(partitionIDs, func(partID int64, _ int) (int64, bool) {
return partID, !lo.Contains(loadedPartitionIDs, partID)
})
if len(lackPartitionIDs) == 0 {
return ErrCollectionLoaded
}
job.undo.CollectionID = req.GetCollectionID()
job.undo.LackPartitions = lackPartitionIDs
log.Info("find partitions to load", zap.Int64s("partitions", lackPartitionIDs))
// 2. loadPartitions on QueryNodes
err = loadPartitions(job.ctx, job.meta, job.cluster, false, req.GetCollectionID(), lackPartitionIDs...)
if err != nil {
return err
}
job.undo.PartitionsLoaded = true
// 3. update next target
_, err = job.targetObserver.UpdateNextTarget(req.GetCollectionID(), partitionIDs...)
if err != nil {
msg := "failed to update next target"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
job.undo.TargetUpdated = true
colExisted := job.meta.CollectionManager.Exist(req.GetCollectionID())
if !colExisted {
// Clear stale replicas, https://github.com/milvus-io/milvus/issues/20444
err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
if err != nil {
msg := "failed to clear stale replicas"
log.Warn(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
}
// 4. create replica if not exist
replicas := job.meta.ReplicaManager.GetByCollection(req.GetCollectionID())
if len(replicas) == 0 {
replicas, err = utils.SpawnReplicasWithRG(job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber())
if err != nil {
msg := "failed to spawn replica for collection"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
for _, replica := range replicas {
log.Info("replica created", zap.Int64("replicaID", replica.GetID()),
zap.Int64s("nodes", replica.GetNodes()), zap.String("resourceGroup", replica.GetResourceGroup()))
}
job.undo.NewReplicaCreated = true
}
// 5. put collection/partitions meta
partitions := lo.Map(lackPartitionIDs, func(partID int64, _ int) *meta.Partition {
return &meta.Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: req.GetCollectionID(),
PartitionID: partID,
Status: querypb.LoadStatus_Loading,
},
CreatedAt: time.Now(),
}
})
collection := &meta.Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: req.GetCollectionID(),
ReplicaNumber: req.GetReplicaNumber(),
Status: querypb.LoadStatus_Loading,
FieldIndexID: req.GetFieldIndexID(),
LoadType: querypb.LoadType_LoadCollection,
},
CreatedAt: time.Now(),
}
err = job.meta.CollectionManager.PutCollection(collection, partitions...)
if err != nil {
msg := "failed to store collection and partitions"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
if !colExisted {
metrics.QueryCoordNumCollections.WithLabelValues().Inc()
}
metrics.QueryCoordNumPartitions.WithLabelValues().Add(float64(len(partitions)))
return nil
}
func (job *LoadCollectionJob) PostExecute() {
if job.Error() != nil && !errors.Is(job.Error(), ErrCollectionLoaded) {
job.undo.RollBack()
}
}
type LoadPartitionJob struct {
*BaseJob
req *querypb.LoadPartitionsRequest
undo *UndoList
dist *meta.DistributionManager
meta *meta.Meta
cluster session.Cluster
targetMgr *meta.TargetManager
targetObserver *observers.TargetObserver
broker meta.Broker
nodeMgr *session.NodeManager
}
func NewLoadPartitionJob(
ctx context.Context,
req *querypb.LoadPartitionsRequest,
dist *meta.DistributionManager,
meta *meta.Meta,
cluster session.Cluster,
targetMgr *meta.TargetManager,
targetObserver *observers.TargetObserver,
broker meta.Broker,
nodeMgr *session.NodeManager,
) *LoadPartitionJob {
return &LoadPartitionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
undo: NewUndoList(ctx, meta, cluster, targetMgr, targetObserver),
dist: dist,
meta: meta,
cluster: cluster,
targetMgr: targetMgr,
targetObserver: targetObserver,
broker: broker,
nodeMgr: nodeMgr,
}
}
func (job *LoadPartitionJob) PreExecute() error {
req := job.req
log := log.Ctx(job.ctx).With(zap.Int64("collectionID", req.GetCollectionID()))
if req.GetReplicaNumber() <= 0 {
log.Info("request doesn't indicate the number of replicas, set it to 1",
zap.Int32("replicaNumber", req.GetReplicaNumber()))
req.ReplicaNumber = 1
}
collection := job.meta.GetCollection(req.GetCollectionID())
if collection == nil {
return nil
}
if collection.GetReplicaNumber() != req.GetReplicaNumber() {
msg := "collection with different replica number existed, release this collection first before changing its replica number"
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
} else if !typeutil.MapEqual(collection.GetFieldIndexID(), req.GetFieldIndexID()) {
msg := fmt.Sprintf("collection with different index %v existed, release this collection first before changing its index",
job.meta.GetFieldIndex(req.GetCollectionID()))
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
}
return nil
}
func (job *LoadPartitionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.Int64s("partitionIDs", req.GetPartitionIDs()),
)
meta.GlobalFailedLoadCache.Remove(req.GetCollectionID())
// 1. Fetch target partitions
loadedPartitionIDs := lo.Map(job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID()),
func(partition *meta.Partition, _ int) int64 {
return partition.GetPartitionID()
})
lackPartitionIDs := lo.FilterMap(req.GetPartitionIDs(), func(partID int64, _ int) (int64, bool) {
return partID, !lo.Contains(loadedPartitionIDs, partID)
})
if len(lackPartitionIDs) == 0 {
return ErrCollectionLoaded
}
job.undo.CollectionID = req.GetCollectionID()
job.undo.LackPartitions = lackPartitionIDs
log.Info("find partitions to load", zap.Int64s("partitions", lackPartitionIDs))
// 2. loadPartitions on QueryNodes
err := loadPartitions(job.ctx, job.meta, job.cluster, false, req.GetCollectionID(), lackPartitionIDs...)
if err != nil {
return err
}
job.undo.PartitionsLoaded = true
// 3. update next target
_, err = job.targetObserver.UpdateNextTarget(req.GetCollectionID(), append(loadedPartitionIDs, lackPartitionIDs...)...)
if err != nil {
msg := "failed to update next target"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
job.undo.TargetUpdated = true
if !job.meta.CollectionManager.Exist(req.GetCollectionID()) {
// Clear stale replicas, https://github.com/milvus-io/milvus/issues/20444
err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
if err != nil {
msg := "failed to clear stale replicas"
log.Warn(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
}
// 4. create replica if not exist
replicas := job.meta.ReplicaManager.GetByCollection(req.GetCollectionID())
if len(replicas) == 0 {
replicas, err = utils.SpawnReplicasWithRG(job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber())
if err != nil {
msg := "failed to spawn replica for collection"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
for _, replica := range replicas {
log.Info("replica created", zap.Int64("replicaID", replica.GetID()),
zap.Int64s("nodes", replica.GetNodes()), zap.String("resourceGroup", replica.GetResourceGroup()))
}
job.undo.NewReplicaCreated = true
}
// 5. put collection/partitions meta
partitions := lo.Map(lackPartitionIDs, func(partID int64, _ int) *meta.Partition {
return &meta.Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: req.GetCollectionID(),
PartitionID: partID,
Status: querypb.LoadStatus_Loading,
},
CreatedAt: time.Now(),
}
})
if !job.meta.CollectionManager.Exist(req.GetCollectionID()) {
collection := &meta.Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: req.GetCollectionID(),
ReplicaNumber: req.GetReplicaNumber(),
Status: querypb.LoadStatus_Loading,
FieldIndexID: req.GetFieldIndexID(),
LoadType: querypb.LoadType_LoadPartition,
},
CreatedAt: time.Now(),
}
err = job.meta.CollectionManager.PutCollection(collection, partitions...)
if err != nil {
msg := "failed to store collection and partitions"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
metrics.QueryCoordNumCollections.WithLabelValues().Inc()
} else { // collection exists, put partitions only
err = job.meta.CollectionManager.PutPartition(partitions...)
if err != nil {
msg := "failed to store partitions"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
}
metrics.QueryCoordNumPartitions.WithLabelValues().Add(float64(len(partitions)))
return nil
}
func (job *LoadPartitionJob) PostExecute() {
if job.Error() != nil && !errors.Is(job.Error(), ErrCollectionLoaded) {
job.undo.RollBack()
}
}
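A rough sketch (an assumption, not code from this commit) of how the QueryCoord service layer is expected to drive the rewritten job: build it with its dependencies, hand it to the job scheduler of this package, and treat ErrCollectionLoaded as "already done". The Scheduler, Add, and Wait names follow existing querycoordv2 job conventions and should be checked against the actual services.go (whose diff is suppressed above).

// submitLoadPartitionJob builds and runs a LoadPartitionJob, ignoring the already-loaded case.
func submitLoadPartitionJob(
	ctx context.Context,
	req *querypb.LoadPartitionsRequest,
	scheduler *Scheduler, // assumed job scheduler with an Add method
	dist *meta.DistributionManager,
	m *meta.Meta,
	cluster session.Cluster,
	targetMgr *meta.TargetManager,
	targetObserver *observers.TargetObserver,
	broker meta.Broker,
	nodeMgr *session.NodeManager,
) error {
	loadJob := NewLoadPartitionJob(ctx, req, dist, m, cluster, targetMgr, targetObserver, broker, nodeMgr)
	scheduler.Add(loadJob)
	if err := loadJob.Wait(); err != nil && !errors.Is(err, ErrCollectionLoaded) {
		return err
	}
	// ErrCollectionLoaded: every requested partition is already loaded, nothing to do.
	return nil
}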

View File

@ -0,0 +1,176 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package job
import (
"context"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/observers"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
)
type ReleaseCollectionJob struct {
*BaseJob
req *querypb.ReleaseCollectionRequest
dist *meta.DistributionManager
meta *meta.Meta
targetMgr *meta.TargetManager
targetObserver *observers.TargetObserver
}
func NewReleaseCollectionJob(ctx context.Context,
req *querypb.ReleaseCollectionRequest,
dist *meta.DistributionManager,
meta *meta.Meta,
targetMgr *meta.TargetManager,
targetObserver *observers.TargetObserver,
) *ReleaseCollectionJob {
return &ReleaseCollectionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
dist: dist,
meta: meta,
targetMgr: targetMgr,
targetObserver: targetObserver,
}
}
func (job *ReleaseCollectionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(zap.Int64("collectionID", req.GetCollectionID()))
if !job.meta.CollectionManager.Exist(req.GetCollectionID()) {
log.Info("release collection end, the collection has not been loaded into QueryNode")
return nil
}
lenPartitions := len(job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID()))
err := job.meta.CollectionManager.RemoveCollection(req.GetCollectionID())
if err != nil {
msg := "failed to remove collection"
log.Warn(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
if err != nil {
msg := "failed to remove replicas"
log.Warn(msg, zap.Error(err))
}
job.targetMgr.RemoveCollection(req.GetCollectionID())
job.targetObserver.ReleaseCollection(req.GetCollectionID())
waitCollectionReleased(job.dist, req.GetCollectionID())
metrics.QueryCoordNumCollections.WithLabelValues().Dec()
metrics.QueryCoordNumPartitions.WithLabelValues().Sub(float64(lenPartitions))
return nil
}
type ReleasePartitionJob struct {
*BaseJob
releasePartitionsOnly bool
req *querypb.ReleasePartitionsRequest
dist *meta.DistributionManager
meta *meta.Meta
cluster session.Cluster
targetMgr *meta.TargetManager
targetObserver *observers.TargetObserver
}
func NewReleasePartitionJob(ctx context.Context,
req *querypb.ReleasePartitionsRequest,
dist *meta.DistributionManager,
meta *meta.Meta,
cluster session.Cluster,
targetMgr *meta.TargetManager,
targetObserver *observers.TargetObserver,
) *ReleasePartitionJob {
return &ReleasePartitionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
dist: dist,
meta: meta,
cluster: cluster,
targetMgr: targetMgr,
targetObserver: targetObserver,
}
}
func (job *ReleasePartitionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.Int64s("partitionIDs", req.GetPartitionIDs()),
)
if !job.meta.CollectionManager.Exist(req.GetCollectionID()) {
log.Info("release collection end, the collection has not been loaded into QueryNode")
return nil
}
loadedPartitions := job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID())
toRelease := lo.FilterMap(loadedPartitions, func(partition *meta.Partition, _ int) (int64, bool) {
return partition.GetPartitionID(), lo.Contains(req.GetPartitionIDs(), partition.GetPartitionID())
})
// If all partitions are released and LoadType is LoadPartition, clear all
if len(toRelease) == len(loadedPartitions) &&
job.meta.GetLoadType(req.GetCollectionID()) == querypb.LoadType_LoadPartition {
log.Info("release partitions covers all partitions, will remove the whole collection")
err := job.meta.CollectionManager.RemoveCollection(req.GetCollectionID())
if err != nil {
msg := "failed to release partitions from store"
log.Warn(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
if err != nil {
log.Warn("failed to remove replicas", zap.Error(err))
}
job.targetMgr.RemoveCollection(req.GetCollectionID())
job.targetObserver.ReleaseCollection(req.GetCollectionID())
metrics.QueryCoordNumCollections.WithLabelValues().Dec()
waitCollectionReleased(job.dist, req.GetCollectionID())
} else {
err := releasePartitions(job.ctx, job.meta, job.cluster, false, req.GetCollectionID(), toRelease...)
if err != nil {
loadPartitions(job.ctx, job.meta, job.cluster, true, req.GetCollectionID(), toRelease...)
return err
}
err = job.meta.CollectionManager.RemovePartition(toRelease...)
if err != nil {
loadPartitions(job.ctx, job.meta, job.cluster, true, req.GetCollectionID(), toRelease...)
msg := "failed to release partitions from store"
log.Warn(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
job.targetMgr.RemovePartition(req.GetCollectionID(), toRelease...)
waitCollectionReleased(job.dist, req.GetCollectionID(), toRelease...)
}
metrics.QueryCoordNumPartitions.WithLabelValues().Sub(float64(len(toRelease)))
return nil
}

View File

@ -0,0 +1,103 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package job
import (
"context"
"time"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
)
type SyncNewCreatedPartitionJob struct {
*BaseJob
req *querypb.SyncNewCreatedPartitionRequest
meta *meta.Meta
cluster session.Cluster
}
func NewSyncNewCreatedPartitionJob(
ctx context.Context,
req *querypb.SyncNewCreatedPartitionRequest,
meta *meta.Meta,
cluster session.Cluster,
) *SyncNewCreatedPartitionJob {
return &SyncNewCreatedPartitionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
meta: meta,
cluster: cluster,
}
}
func (job *SyncNewCreatedPartitionJob) PreExecute() error {
// skip if the collection is not loaded, or it was loaded with LoadType LoadPartition
collection := job.meta.GetCollection(job.req.GetCollectionID())
if collection == nil || collection.GetLoadType() == querypb.LoadType_LoadPartition {
return ErrPartitionNotInTarget
}
// skip if the partition already exists in the meta
if partition := job.meta.GetPartition(job.req.GetPartitionID()); partition != nil {
return ErrPartitionNotInTarget
}
return nil
}
func (job *SyncNewCreatedPartitionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.Int64("partitionID", req.GetPartitionID()),
)
err := loadPartitions(job.ctx, job.meta, job.cluster, false, req.GetCollectionID(), req.GetPartitionID())
if err != nil {
return err
}
partition := &meta.Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: req.GetCollectionID(),
PartitionID: req.GetPartitionID(),
Status: querypb.LoadStatus_Loaded,
},
LoadPercentage: 100,
CreatedAt: time.Now(),
}
err = job.meta.CollectionManager.PutPartition(partition)
if err != nil {
msg := "failed to store partitions"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
return nil
}
func (job *SyncNewCreatedPartitionJob) PostExecute() {
if job.Error() != nil && !errors.Is(job.Error(), ErrPartitionNotInTarget) {
releasePartitions(job.ctx, job.meta, job.cluster, true, job.req.GetCollectionID(), job.req.GetPartitionID())
}
}
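
NOTE (illustrative, not part of the change): a minimal sketch of how this job is expected to be driven. QueryCoord wraps the incoming SyncNewCreatedPartition request into a SyncNewCreatedPartitionJob and pushes it onto the job scheduler, the same way the tests below do. The wrapper name and the *Scheduler type assumed for NewScheduler() are illustrative assumptions, not the real handler code.

// handleSyncNewCreatedPartition is a hypothetical wrapper (package job context)
// showing the intended call pattern; it is not the actual QueryCoord handler.
func handleSyncNewCreatedPartition(
	ctx context.Context,
	scheduler *Scheduler, // assumed concrete type returned by NewScheduler()
	metaMgr *meta.Meta,
	cluster session.Cluster,
	req *querypb.SyncNewCreatedPartitionRequest,
) error {
	syncJob := NewSyncNewCreatedPartitionJob(ctx, req, metaMgr, cluster)
	scheduler.Add(syncJob)
	// ErrPartitionNotInTarget means the collection is not loaded (or was loaded
	// by partition), so there is nothing to sync and the request is a no-op.
	if err := syncJob.Wait(); err != nil && !errors.Is(err, ErrPartitionNotInTarget) {
		return err
	}
	return nil
}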

View File

@ -21,10 +21,11 @@ import (
 	"testing"
 
 	"github.com/cockroachdb/errors"
+	"github.com/samber/lo"
 	"github.com/stretchr/testify/mock"
 	"github.com/stretchr/testify/suite"
 
+	"github.com/milvus-io/milvus-proto/go-api/commonpb"
 	"github.com/milvus-io/milvus/internal/kv"
 	etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
 	"github.com/milvus-io/milvus/internal/proto/datapb"
@ -33,6 +34,7 @@ import (
 	"github.com/milvus-io/milvus/internal/querycoordv2/observers"
 	. "github.com/milvus-io/milvus/internal/querycoordv2/params"
 	"github.com/milvus-io/milvus/internal/querycoordv2/session"
+	"github.com/milvus-io/milvus/internal/querycoordv2/utils"
 	"github.com/milvus-io/milvus/internal/util/etcd"
 )
@ -56,6 +58,7 @@ type JobSuite struct {
 	store          meta.Store
 	dist           *meta.DistributionManager
 	meta           *meta.Meta
+	cluster        *session.MockCluster
 	targetMgr      *meta.TargetManager
 	targetObserver *observers.TargetObserver
 	broker         *meta.MockBroker
@ -70,8 +73,8 @@ func (suite *JobSuite) SetupSuite() {
 	suite.collections = []int64{1000, 1001}
 	suite.partitions = map[int64][]int64{
-		1000: {100, 101},
-		1001: {102, 103},
+		1000: {100, 101, 102},
+		1001: {103, 104, 105},
 	}
 	suite.channels = map[int64][]string{
 		1000: {"1000-dmc0", "1000-dmc1"},
@ -81,10 +84,12 @@ func (suite *JobSuite) SetupSuite() {
 		1000: {
 			100: {1, 2},
 			101: {3, 4},
+			102: {5, 6},
 		},
 		1001: {
-			102: {5, 6},
 			103: {7, 8},
+			104: {9, 10},
+			105: {11, 12},
 		},
 	}
 	suite.loadTypes = map[int64]querypb.LoadType{
@ -115,6 +120,14 @@ func (suite *JobSuite) SetupSuite() {
 			Return(vChannels, segmentBinlogs, nil)
 		}
 	}
+
+	suite.cluster = session.NewMockCluster(suite.T())
+	suite.cluster.EXPECT().
+		LoadPartitions(mock.Anything, mock.Anything, mock.Anything).
+		Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil)
+	suite.cluster.EXPECT().
+		ReleasePartitions(mock.Anything, mock.Anything, mock.Anything).
+		Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil)
 }
 
 func (suite *JobSuite) SetupTest() {
@ -140,6 +153,7 @@ func (suite *JobSuite) SetupTest() {
 		suite.dist,
 		suite.broker,
 	)
+	suite.targetObserver.Start(context.Background())
 
 	suite.scheduler = NewScheduler()
 	suite.scheduler.Start(context.Background())
@ -160,21 +174,16 @@ func (suite *JobSuite) SetupTest() {
 func (suite *JobSuite) TearDownTest() {
 	suite.kv.Close()
 	suite.scheduler.Stop()
+	suite.targetObserver.Stop()
 }
 
 func (suite *JobSuite) BeforeTest(suiteName, testName string) {
-	switch testName {
-	case "TestLoadCollection":
-		for collection, partitions := range suite.partitions {
-			if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
-				continue
-			}
-			suite.broker.EXPECT().
-				GetPartitions(mock.Anything, collection).
-				Return(partitions, nil)
-		}
+	for collection, partitions := range suite.partitions {
+		suite.broker.EXPECT().
+			GetPartitions(mock.Anything, collection).
+			Return(partitions, nil)
 	}
 }
 
 func (suite *JobSuite) TestLoadCollection() {
 	ctx := context.Background()
@ -195,7 +204,9 @@ func (suite *JobSuite) TestLoadCollection() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -204,7 +215,7 @@ func (suite *JobSuite) TestLoadCollection() {
suite.NoError(err) suite.NoError(err)
suite.EqualValues(1, suite.meta.GetReplicaNumber(collection)) suite.EqualValues(1, suite.meta.GetReplicaNumber(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection) suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.assertLoaded(collection) suite.assertCollectionLoaded(collection)
} }
// Test load again // Test load again
@ -220,7 +231,9 @@ func (suite *JobSuite) TestLoadCollection() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -243,7 +256,9 @@ func (suite *JobSuite) TestLoadCollection() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -268,13 +283,15 @@ func (suite *JobSuite) TestLoadCollection() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
suite.scheduler.Add(job) suite.scheduler.Add(job)
err := job.Wait() err := job.Wait()
suite.ErrorIs(err, ErrLoadParameterMismatched) suite.ErrorIs(err, ErrCollectionLoaded)
} }
suite.meta.ResourceManager.AddResourceGroup("rg1") suite.meta.ResourceManager.AddResourceGroup("rg1")
@ -292,7 +309,9 @@ func (suite *JobSuite) TestLoadCollection() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -302,7 +321,7 @@ func (suite *JobSuite) TestLoadCollection() {
// Load with 3 replica on 3 rg // Load with 3 replica on 3 rg
req = &querypb.LoadCollectionRequest{ req = &querypb.LoadCollectionRequest{
CollectionID: 1002, CollectionID: 1001,
ReplicaNumber: 3, ReplicaNumber: 3,
ResourceGroups: []string{"rg1", "rg2", "rg3"}, ResourceGroups: []string{"rg1", "rg2", "rg3"},
} }
@ -311,7 +330,9 @@ func (suite *JobSuite) TestLoadCollection() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -338,7 +359,9 @@ func (suite *JobSuite) TestLoadCollectionWithReplicas() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -368,7 +391,9 @@ func (suite *JobSuite) TestLoadCollectionWithDiffIndex() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -377,7 +402,7 @@ func (suite *JobSuite) TestLoadCollectionWithDiffIndex() {
suite.NoError(err) suite.NoError(err)
suite.EqualValues(1, suite.meta.GetReplicaNumber(collection)) suite.EqualValues(1, suite.meta.GetReplicaNumber(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection, suite.partitions[collection]...) suite.targetMgr.UpdateCollectionCurrentTarget(collection, suite.partitions[collection]...)
suite.assertLoaded(collection) suite.assertCollectionLoaded(collection)
} }
// Test load with different index // Test load with different index
@ -396,7 +421,9 @@ func (suite *JobSuite) TestLoadCollectionWithDiffIndex() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -425,7 +452,9 @@ func (suite *JobSuite) TestLoadPartition() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -434,7 +463,7 @@ func (suite *JobSuite) TestLoadPartition() {
suite.NoError(err) suite.NoError(err)
suite.EqualValues(1, suite.meta.GetReplicaNumber(collection)) suite.EqualValues(1, suite.meta.GetReplicaNumber(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection, suite.partitions[collection]...) suite.targetMgr.UpdateCollectionCurrentTarget(collection, suite.partitions[collection]...)
suite.assertLoaded(collection) suite.assertCollectionLoaded(collection)
} }
// Test load partition again // Test load partition again
@ -453,7 +482,9 @@ func (suite *JobSuite) TestLoadPartition() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -478,7 +509,9 @@ func (suite *JobSuite) TestLoadPartition() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -496,20 +529,22 @@ func (suite *JobSuite) TestLoadPartition() {
req := &querypb.LoadPartitionsRequest{ req := &querypb.LoadPartitionsRequest{
CollectionID: collection, CollectionID: collection,
PartitionIDs: append(suite.partitions[collection], 200), PartitionIDs: append(suite.partitions[collection], 200),
ReplicaNumber: 3, ReplicaNumber: 1,
} }
job := NewLoadPartitionJob( job := NewLoadPartitionJob(
ctx, ctx,
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
suite.scheduler.Add(job) suite.scheduler.Add(job)
err := job.Wait() err := job.Wait()
suite.ErrorIs(err, ErrLoadParameterMismatched) suite.NoError(err)
} }
// Test load collection while partitions exists // Test load collection while partitions exists
@ -527,13 +562,15 @@ func (suite *JobSuite) TestLoadPartition() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
suite.scheduler.Add(job) suite.scheduler.Add(job)
err := job.Wait() err := job.Wait()
suite.ErrorIs(err, ErrLoadParameterMismatched) suite.ErrorIs(err, ErrCollectionLoaded)
} }
suite.meta.ResourceManager.AddResourceGroup("rg1") suite.meta.ResourceManager.AddResourceGroup("rg1")
@ -541,9 +578,11 @@ func (suite *JobSuite) TestLoadPartition() {
suite.meta.ResourceManager.AddResourceGroup("rg3") suite.meta.ResourceManager.AddResourceGroup("rg3")
// test load 3 replica in 1 rg, should pass rg check // test load 3 replica in 1 rg, should pass rg check
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(999)).Return([]int64{888}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(999), int64(888)).Return(nil, nil, nil)
req := &querypb.LoadPartitionsRequest{ req := &querypb.LoadPartitionsRequest{
CollectionID: 100, CollectionID: 999,
PartitionIDs: []int64{1001}, PartitionIDs: []int64{888},
ReplicaNumber: 3, ReplicaNumber: 3,
ResourceGroups: []string{"rg1"}, ResourceGroups: []string{"rg1"},
} }
@ -552,7 +591,9 @@ func (suite *JobSuite) TestLoadPartition() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -561,9 +602,11 @@ func (suite *JobSuite) TestLoadPartition() {
suite.Contains(err.Error(), meta.ErrNodeNotEnough.Error()) suite.Contains(err.Error(), meta.ErrNodeNotEnough.Error())
// test load 3 replica in 3 rg, should pass rg check // test load 3 replica in 3 rg, should pass rg check
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(999)).Return([]int64{888}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(999), int64(888)).Return(nil, nil, nil)
req = &querypb.LoadPartitionsRequest{ req = &querypb.LoadPartitionsRequest{
CollectionID: 102, CollectionID: 999,
PartitionIDs: []int64{1001}, PartitionIDs: []int64{888},
ReplicaNumber: 3, ReplicaNumber: 3,
ResourceGroups: []string{"rg1", "rg2", "rg3"}, ResourceGroups: []string{"rg1", "rg2", "rg3"},
} }
@ -572,7 +615,9 @@ func (suite *JobSuite) TestLoadPartition() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -581,6 +626,120 @@ func (suite *JobSuite) TestLoadPartition() {
suite.Contains(err.Error(), meta.ErrNodeNotEnough.Error()) suite.Contains(err.Error(), meta.ErrNodeNotEnough.Error())
} }
func (suite *JobSuite) TestDynamicLoad() {
ctx := context.Background()
collection := suite.collections[0]
p0, p1, p2 := suite.partitions[collection][0], suite.partitions[collection][1], suite.partitions[collection][2]
newLoadPartJob := func(partitions ...int64) *LoadPartitionJob {
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: partitions,
ReplicaNumber: 1,
}
job := NewLoadPartitionJob(
ctx,
req,
suite.dist,
suite.meta,
suite.cluster,
suite.targetMgr,
suite.targetObserver,
suite.broker,
suite.nodeMgr,
)
return job
}
newLoadColJob := func() *LoadCollectionJob {
req := &querypb.LoadCollectionRequest{
CollectionID: collection,
ReplicaNumber: 1,
}
job := NewLoadCollectionJob(
ctx,
req,
suite.dist,
suite.meta,
suite.cluster,
suite.targetMgr,
suite.targetObserver,
suite.broker,
suite.nodeMgr,
)
return job
}
// loaded: none
// action: load p0, p1, p2
// expect: p0, p1, p2 loaded
job := newLoadPartJob(p0, p1, p2)
suite.scheduler.Add(job)
err := job.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.assertPartitionLoaded(collection, p0, p1, p2)
// loaded: p0, p1, p2
// action: load p0, p1, p2
// expect: do nothing, p0, p1, p2 loaded
job = newLoadPartJob(p0, p1, p2)
suite.scheduler.Add(job)
err = job.Wait()
suite.ErrorIs(err, ErrCollectionLoaded)
suite.assertPartitionLoaded(collection)
// loaded: p0, p1
// action: load p2
// expect: p0, p1, p2 loaded
suite.releaseAll()
job = newLoadPartJob(p0, p1)
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.assertPartitionLoaded(collection, p0, p1)
job = newLoadPartJob(p2)
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.assertPartitionLoaded(collection, p2)
// loaded: p0, p1
// action: load p1, p2
// expect: p0, p1, p2 loaded
suite.releaseAll()
job = newLoadPartJob(p0, p1)
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.assertPartitionLoaded(collection, p0, p1)
job = newLoadPartJob(p1, p2)
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.assertPartitionLoaded(collection, p2)
// loaded: p0, p1
// action: load col
// expect: col loaded
suite.releaseAll()
job = newLoadPartJob(p0, p1)
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.assertPartitionLoaded(collection, p0, p1)
colJob := newLoadColJob()
suite.scheduler.Add(colJob)
err = colJob.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.assertPartitionLoaded(collection, p2)
}
func (suite *JobSuite) TestLoadPartitionWithReplicas() { func (suite *JobSuite) TestLoadPartitionWithReplicas() {
ctx := context.Background() ctx := context.Background()
@ -600,7 +759,9 @@ func (suite *JobSuite) TestLoadPartitionWithReplicas() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -631,7 +792,9 @@ func (suite *JobSuite) TestLoadPartitionWithDiffIndex() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -640,7 +803,7 @@ func (suite *JobSuite) TestLoadPartitionWithDiffIndex() {
suite.NoError(err) suite.NoError(err)
suite.EqualValues(1, suite.meta.GetReplicaNumber(collection)) suite.EqualValues(1, suite.meta.GetReplicaNumber(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection, suite.partitions[collection]...) suite.targetMgr.UpdateCollectionCurrentTarget(collection, suite.partitions[collection]...)
suite.assertLoaded(collection) suite.assertCollectionLoaded(collection)
} }
// Test load partition with different index // Test load partition with different index
@ -661,7 +824,9 @@ func (suite *JobSuite) TestLoadPartitionWithDiffIndex() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -692,7 +857,7 @@ func (suite *JobSuite) TestReleaseCollection() {
suite.scheduler.Add(job) suite.scheduler.Add(job)
err := job.Wait() err := job.Wait()
suite.NoError(err) suite.NoError(err)
suite.assertReleased(collection) suite.assertCollectionReleased(collection)
} }
// Test release again // Test release again
@ -711,7 +876,7 @@ func (suite *JobSuite) TestReleaseCollection() {
suite.scheduler.Add(job) suite.scheduler.Add(job)
err := job.Wait() err := job.Wait()
suite.NoError(err) suite.NoError(err)
suite.assertReleased(collection) suite.assertCollectionReleased(collection)
} }
} }
@ -731,18 +896,14 @@ func (suite *JobSuite) TestReleasePartition() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver, suite.targetObserver,
) )
suite.scheduler.Add(job) suite.scheduler.Add(job)
err := job.Wait() err := job.Wait()
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection {
suite.ErrorIs(err, ErrLoadParameterMismatched)
suite.assertLoaded(collection)
} else {
suite.NoError(err) suite.NoError(err)
suite.assertReleased(collection) suite.assertPartitionReleased(collection, suite.partitions[collection]...)
}
} }
// Test release again // Test release again
@ -756,18 +917,14 @@ func (suite *JobSuite) TestReleasePartition() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver, suite.targetObserver,
) )
suite.scheduler.Add(job) suite.scheduler.Add(job)
err := job.Wait() err := job.Wait()
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection {
suite.ErrorIs(err, ErrLoadParameterMismatched)
suite.assertLoaded(collection)
} else {
suite.NoError(err) suite.NoError(err)
suite.assertReleased(collection) suite.assertPartitionReleased(collection, suite.partitions[collection]...)
}
} }
// Test release partial partitions // Test release partial partitions
@ -783,22 +940,112 @@ func (suite *JobSuite) TestReleasePartition() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver, suite.targetObserver,
) )
suite.scheduler.Add(job) suite.scheduler.Add(job)
err := job.Wait() err := job.Wait()
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection {
suite.ErrorIs(err, ErrLoadParameterMismatched)
suite.assertLoaded(collection)
} else {
suite.NoError(err) suite.NoError(err)
suite.True(suite.meta.Exist(collection)) suite.True(suite.meta.Exist(collection))
partitions := suite.meta.GetPartitionsByCollection(collection) partitions := suite.meta.GetPartitionsByCollection(collection)
suite.Len(partitions, 1) suite.Len(partitions, 1)
suite.Equal(suite.partitions[collection][0], partitions[0].GetPartitionID()) suite.Equal(suite.partitions[collection][0], partitions[0].GetPartitionID())
suite.assertPartitionReleased(collection, suite.partitions[collection][1:]...)
} }
} }
func (suite *JobSuite) TestDynamicRelease() {
ctx := context.Background()
col0, col1 := suite.collections[0], suite.collections[1]
p0, p1, p2 := suite.partitions[col0][0], suite.partitions[col0][1], suite.partitions[col0][2]
p3, p4, p5 := suite.partitions[col1][0], suite.partitions[col1][1], suite.partitions[col1][2]
newReleasePartJob := func(col int64, partitions ...int64) *ReleasePartitionJob {
req := &querypb.ReleasePartitionsRequest{
CollectionID: col,
PartitionIDs: partitions,
}
job := NewReleasePartitionJob(
ctx,
req,
suite.dist,
suite.meta,
suite.cluster,
suite.targetMgr,
suite.targetObserver,
)
return job
}
newReleaseColJob := func(col int64) *ReleaseCollectionJob {
req := &querypb.ReleaseCollectionRequest{
CollectionID: col,
}
job := NewReleaseCollectionJob(
ctx,
req,
suite.dist,
suite.meta,
suite.targetMgr,
suite.targetObserver,
)
return job
}
// loaded: p0, p1, p2
// action: release p0
// expect: p0 released, p1, p2 loaded
suite.loadAll()
job := newReleasePartJob(col0, p0)
suite.scheduler.Add(job)
err := job.Wait()
suite.NoError(err)
suite.assertPartitionReleased(col0, p0)
suite.assertPartitionLoaded(col0, p1, p2)
// loaded: p1, p2
// action: release p0, p1
// expect: p1 released, p2 loaded
job = newReleasePartJob(col0, p0, p1)
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.assertPartitionReleased(col0, p0, p1)
suite.assertPartitionLoaded(col0, p2)
// loaded: p2
// action: release p2
// expect: loadType=col: col loaded, p2 released
job = newReleasePartJob(col0, p2)
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.assertPartitionReleased(col0, p0, p1, p2)
suite.True(suite.meta.Exist(col0))
// loaded: p0, p1, p2
// action: release col
// expect: col released
suite.releaseAll()
suite.loadAll()
releaseColJob := newReleaseColJob(col0)
suite.scheduler.Add(releaseColJob)
err = releaseColJob.Wait()
suite.NoError(err)
suite.assertCollectionReleased(col0)
suite.assertPartitionReleased(col0, p0, p1, p2)
// loaded: p3, p4, p5
// action: release p3, p4, p5
// expect: loadType=partition: col released
suite.releaseAll()
suite.loadAll()
job = newReleasePartJob(col1, p3, p4, p5)
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.assertCollectionReleased(col1)
suite.assertPartitionReleased(col1, p3, p4, p5)
} }
func (suite *JobSuite) TestLoadCollectionStoreFailed() { func (suite *JobSuite) TestLoadCollectionStoreFailed() {
@ -818,14 +1065,10 @@ func (suite *JobSuite) TestLoadCollectionStoreFailed() {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection { if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
continue continue
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil)
err := errors.New("failed to store collection") err := errors.New("failed to store collection")
store.EXPECT().SaveReplica(mock.Anything).Return(nil) store.EXPECT().SaveReplica(mock.Anything).Return(nil)
store.EXPECT().SaveCollection(&querypb.CollectionLoadInfo{ store.EXPECT().SaveCollection(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(err)
CollectionID: collection,
ReplicaNumber: 1,
Status: querypb.LoadStatus_Loading,
}).Return(err)
store.EXPECT().ReleaseReplicas(collection).Return(nil) store.EXPECT().ReleaseReplicas(collection).Return(nil)
req := &querypb.LoadCollectionRequest{ req := &querypb.LoadCollectionRequest{
@ -836,7 +1079,9 @@ func (suite *JobSuite) TestLoadCollectionStoreFailed() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -866,7 +1111,7 @@ func (suite *JobSuite) TestLoadPartitionStoreFailed() {
} }
store.EXPECT().SaveReplica(mock.Anything).Return(nil) store.EXPECT().SaveReplica(mock.Anything).Return(nil)
store.EXPECT().SavePartition(mock.Anything, mock.Anything).Return(err) store.EXPECT().SaveCollection(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(err)
store.EXPECT().ReleaseReplicas(collection).Return(nil) store.EXPECT().ReleaseReplicas(collection).Return(nil)
req := &querypb.LoadPartitionsRequest{ req := &querypb.LoadPartitionsRequest{
@ -878,7 +1123,9 @@ func (suite *JobSuite) TestLoadPartitionStoreFailed() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -892,6 +1139,9 @@ func (suite *JobSuite) TestLoadCreateReplicaFailed() {
// Store replica failed // Store replica failed
suite.meta = meta.NewMeta(ErrorIDAllocator(), suite.store, session.NewNodeManager()) suite.meta = meta.NewMeta(ErrorIDAllocator(), suite.store, session.NewNodeManager())
for _, collection := range suite.collections { for _, collection := range suite.collections {
suite.broker.EXPECT().
GetPartitions(mock.Anything, collection).
Return(suite.partitions[collection], nil)
req := &querypb.LoadCollectionRequest{ req := &querypb.LoadCollectionRequest{
CollectionID: collection, CollectionID: collection,
} }
@ -900,7 +1150,9 @@ func (suite *JobSuite) TestLoadCreateReplicaFailed() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -910,6 +1162,59 @@ func (suite *JobSuite) TestLoadCreateReplicaFailed() {
} }
} }
func (suite *JobSuite) TestSyncNewCreatedPartition() {
newPartition := int64(999)
// test sync new created partition
suite.loadAll()
req := &querypb.SyncNewCreatedPartitionRequest{
CollectionID: suite.collections[0],
PartitionID: newPartition,
}
job := NewSyncNewCreatedPartitionJob(
context.Background(),
req,
suite.meta,
suite.cluster,
)
suite.scheduler.Add(job)
err := job.Wait()
suite.NoError(err)
partition := suite.meta.CollectionManager.GetPartition(newPartition)
suite.NotNil(partition)
suite.Equal(querypb.LoadStatus_Loaded, partition.GetStatus())
// test collection not loaded
req = &querypb.SyncNewCreatedPartitionRequest{
CollectionID: int64(888),
PartitionID: newPartition,
}
job = NewSyncNewCreatedPartitionJob(
context.Background(),
req,
suite.meta,
suite.cluster,
)
suite.scheduler.Add(job)
err = job.Wait()
suite.ErrorIs(err, ErrPartitionNotInTarget)
// test collection loaded, but its loadType is loadPartition
req = &querypb.SyncNewCreatedPartitionRequest{
CollectionID: suite.collections[1],
PartitionID: newPartition,
}
job = NewSyncNewCreatedPartitionJob(
context.Background(),
req,
suite.meta,
suite.cluster,
)
suite.scheduler.Add(job)
err = job.Wait()
suite.ErrorIs(err, ErrPartitionNotInTarget)
}
func (suite *JobSuite) loadAll() { func (suite *JobSuite) loadAll() {
ctx := context.Background() ctx := context.Background()
for _, collection := range suite.collections { for _, collection := range suite.collections {
@ -922,7 +1227,9 @@ func (suite *JobSuite) loadAll() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -932,6 +1239,7 @@ func (suite *JobSuite) loadAll() {
suite.EqualValues(1, suite.meta.GetReplicaNumber(collection)) suite.EqualValues(1, suite.meta.GetReplicaNumber(collection))
suite.True(suite.meta.Exist(collection)) suite.True(suite.meta.Exist(collection))
suite.NotNil(suite.meta.GetCollection(collection)) suite.NotNil(suite.meta.GetCollection(collection))
suite.NotNil(suite.meta.GetPartitionsByCollection(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection) suite.targetMgr.UpdateCollectionCurrentTarget(collection)
} else { } else {
req := &querypb.LoadPartitionsRequest{ req := &querypb.LoadPartitionsRequest{
@ -943,7 +1251,9 @@ func (suite *JobSuite) loadAll() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -952,6 +1262,7 @@ func (suite *JobSuite) loadAll() {
suite.NoError(err) suite.NoError(err)
suite.EqualValues(1, suite.meta.GetReplicaNumber(collection)) suite.EqualValues(1, suite.meta.GetReplicaNumber(collection))
suite.True(suite.meta.Exist(collection)) suite.True(suite.meta.Exist(collection))
suite.NotNil(suite.meta.GetCollection(collection))
suite.NotNil(suite.meta.GetPartitionsByCollection(collection)) suite.NotNil(suite.meta.GetPartitionsByCollection(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection) suite.targetMgr.UpdateCollectionCurrentTarget(collection)
} }
@ -975,24 +1286,43 @@ func (suite *JobSuite) releaseAll() {
suite.scheduler.Add(job) suite.scheduler.Add(job)
err := job.Wait() err := job.Wait()
suite.NoError(err) suite.NoError(err)
suite.assertReleased(collection) suite.assertCollectionReleased(collection)
} }
} }
func (suite *JobSuite) assertLoaded(collection int64) { func (suite *JobSuite) assertCollectionLoaded(collection int64) {
suite.True(suite.meta.Exist(collection)) suite.True(suite.meta.Exist(collection))
suite.NotEqual(0, len(suite.meta.ReplicaManager.GetByCollection(collection)))
for _, channel := range suite.channels[collection] { for _, channel := range suite.channels[collection] {
suite.NotNil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget)) suite.NotNil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget))
} }
for _, partitions := range suite.segments[collection] { for _, segments := range suite.segments[collection] {
for _, segment := range partitions { for _, segment := range segments {
suite.NotNil(suite.targetMgr.GetHistoricalSegment(collection, segment, meta.CurrentTarget)) suite.NotNil(suite.targetMgr.GetHistoricalSegment(collection, segment, meta.CurrentTarget))
} }
} }
} }
func (suite *JobSuite) assertReleased(collection int64) { func (suite *JobSuite) assertPartitionLoaded(collection int64, partitionIDs ...int64) {
suite.True(suite.meta.Exist(collection))
suite.NotEqual(0, len(suite.meta.ReplicaManager.GetByCollection(collection)))
for _, channel := range suite.channels[collection] {
suite.NotNil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget))
}
for partitionID, segments := range suite.segments[collection] {
if !lo.Contains(partitionIDs, partitionID) {
continue
}
suite.NotNil(suite.meta.GetPartition(partitionID))
for _, segment := range segments {
suite.NotNil(suite.targetMgr.GetHistoricalSegment(collection, segment, meta.CurrentTarget))
}
}
}
func (suite *JobSuite) assertCollectionReleased(collection int64) {
suite.False(suite.meta.Exist(collection)) suite.False(suite.meta.Exist(collection))
suite.Equal(0, len(suite.meta.ReplicaManager.GetByCollection(collection)))
for _, channel := range suite.channels[collection] { for _, channel := range suite.channels[collection] {
suite.Nil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget)) suite.Nil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget))
} }
@ -1003,6 +1333,16 @@ func (suite *JobSuite) assertReleased(collection int64) {
} }
} }
func (suite *JobSuite) assertPartitionReleased(collection int64, partitionIDs ...int64) {
for _, partition := range partitionIDs {
suite.Nil(suite.meta.GetPartition(partition))
segments := suite.segments[collection][partition]
for _, segment := range segments {
suite.Nil(suite.targetMgr.GetHistoricalSegment(collection, segment, meta.CurrentTarget))
}
}
}
func TestJob(t *testing.T) { func TestJob(t *testing.T) {
suite.Run(t, new(JobSuite)) suite.Run(t, new(JobSuite))
} }

View File

@ -0,0 +1,68 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package job
import (
"context"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/observers"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
)
type UndoList struct {
PartitionsLoaded bool // whether partitions were loaded on QueryNodes during loading
TargetUpdated bool // whether the target was updated during loading
NewReplicaCreated bool // whether new replicas were created during loading
CollectionID int64
LackPartitions []int64
ctx context.Context
meta *meta.Meta
cluster session.Cluster
targetMgr *meta.TargetManager
targetObserver *observers.TargetObserver
}
func NewUndoList(ctx context.Context, meta *meta.Meta,
cluster session.Cluster, targetMgr *meta.TargetManager, targetObserver *observers.TargetObserver) *UndoList {
return &UndoList{
ctx: ctx,
meta: meta,
cluster: cluster,
targetMgr: targetMgr,
targetObserver: targetObserver,
}
}
func (u *UndoList) RollBack() {
if u.PartitionsLoaded {
releasePartitions(u.ctx, u.meta, u.cluster, true, u.CollectionID, u.LackPartitions...)
}
if u.TargetUpdated {
if !u.meta.CollectionManager.Exist(u.CollectionID) {
u.targetMgr.RemoveCollection(u.CollectionID)
u.targetObserver.ReleaseCollection(u.CollectionID)
} else {
u.targetMgr.RemovePartition(u.CollectionID, u.LackPartitions...)
}
}
if u.NewReplicaCreated {
u.meta.ReplicaManager.RemoveCollection(u.CollectionID)
}
}
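
NOTE (illustrative, not part of the change): a minimal sketch of how a load job is expected to drive UndoList. Each flag is set only after the corresponding step succeeds, so RollBack undoes exactly the work that was actually done. The helper name and the step layout are assumptions, not the real LoadCollectionJob/LoadPartitionJob code.

// loadWithUndo is a hypothetical helper (package job context) sketching the
// rollback pattern around UndoList.
func loadWithUndo(
	ctx context.Context,
	m *meta.Meta,
	cluster session.Cluster,
	targetMgr *meta.TargetManager,
	targetObserver *observers.TargetObserver,
	collectionID int64,
	lackPartitions []int64,
) (err error) {
	undo := NewUndoList(ctx, m, cluster, targetMgr, targetObserver)
	undo.CollectionID = collectionID
	undo.LackPartitions = lackPartitions
	defer func() {
		if err != nil {
			undo.RollBack() // best effort: undoes only the steps flagged below
		}
	}()

	// Step 1: notify QueryNodes about the partitions being loaded.
	if err = loadPartitions(ctx, m, cluster, false, collectionID, lackPartitions...); err != nil {
		return err
	}
	undo.PartitionsLoaded = true

	// Later steps (update target, create replicas, persist meta) would follow,
	// setting undo.TargetUpdated / undo.NewReplicaCreated as each one succeeds.
	return nil
}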

View File

@ -17,11 +17,17 @@
package job package job
import ( import (
"context"
"fmt"
"time" "time"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/util/typeutil"
) )
// waitCollectionReleased blocks until // waitCollectionReleased blocks until
@ -49,3 +55,57 @@ func waitCollectionReleased(dist *meta.DistributionManager, collection int64, pa
time.Sleep(200 * time.Millisecond) time.Sleep(200 * time.Millisecond)
} }
} }
func loadPartitions(ctx context.Context, meta *meta.Meta, cluster session.Cluster,
ignoreErr bool, collection int64, partitions ...int64) error {
replicas := meta.ReplicaManager.GetByCollection(collection)
loadReq := &querypb.LoadPartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadPartitions,
},
CollectionID: collection,
PartitionIDs: partitions,
}
for _, replica := range replicas {
for _, node := range replica.GetNodes() {
status, err := cluster.LoadPartitions(ctx, node, loadReq)
if ignoreErr {
continue
}
if err != nil {
return err
}
if status.GetErrorCode() != commonpb.ErrorCode_Success {
return fmt.Errorf("QueryNode failed to loadPartition, nodeID=%d, err=%s", node, status.GetReason())
}
}
}
return nil
}
func releasePartitions(ctx context.Context, meta *meta.Meta, cluster session.Cluster,
ignoreErr bool, collection int64, partitions ...int64) error {
replicas := meta.ReplicaManager.GetByCollection(collection)
releaseReq := &querypb.ReleasePartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ReleasePartitions,
},
CollectionID: collection,
PartitionIDs: partitions,
}
for _, replica := range replicas {
for _, node := range replica.GetNodes() {
status, err := cluster.ReleasePartitions(ctx, node, releaseReq)
if ignoreErr {
continue
}
if err != nil {
return err
}
if status.GetErrorCode() != commonpb.ErrorCode_Success {
return fmt.Errorf("QueryNode failed to releasePartitions, nodeID=%d, err=%s", node, status.GetReason())
}
}
}
return nil
}
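
NOTE (illustrative, not part of the change): the ignoreErr flag is what lets the jobs reuse these helpers both on the forward path (strict, abort on the first node error) and as best-effort rollback. A sketch of that calling convention, with a hypothetical persist callback standing in for the meta-store write the real jobs perform between the two cluster calls:

// applyPartitionLoad is a hypothetical helper (package job context); persist is
// a placeholder for the meta-store write done by the real load jobs.
func applyPartitionLoad(
	ctx context.Context,
	m *meta.Meta,
	cluster session.Cluster,
	collectionID int64,
	persist func() error,
	partitions ...int64,
) error {
	// Forward path: strict, the first RPC error or non-Success status aborts.
	if err := loadPartitions(ctx, m, cluster, false, collectionID, partitions...); err != nil {
		return err
	}
	if err := persist(); err != nil {
		// Rollback path: ignoreErr=true, every node of every replica is asked
		// to release and individual failures are swallowed.
		releasePartitions(ctx, m, cluster, true, collectionID, partitions...)
		return err
	}
	return nil
}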

View File

@ -17,15 +17,19 @@
package meta package meta
import ( import (
"context"
"sync" "sync"
"time" "time"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/merr" "github.com/milvus-io/milvus/internal/util/merr"
"github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/internal/util/typeutil"
. "github.com/milvus-io/milvus/internal/util/typeutil" . "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/samber/lo"
) )
type Collection struct { type Collection struct {
@ -72,7 +76,7 @@ func NewCollectionManager(store Store) *CollectionManager {
// Recover recovers collections from kv store, // Recover recovers collections from kv store,
// panics if failed // panics if failed
func (m *CollectionManager) Recover() error { func (m *CollectionManager) Recover(broker Broker) error {
collections, err := m.store.GetCollections() collections, err := m.store.GetCollections()
if err != nil { if err != nil {
return err return err
@ -88,7 +92,6 @@ func (m *CollectionManager) Recover() error {
m.store.ReleaseCollection(collection.GetCollectionID()) m.store.ReleaseCollection(collection.GetCollectionID())
continue continue
} }
m.collections[collection.CollectionID] = &Collection{ m.collections[collection.CollectionID] = &Collection{
CollectionLoadInfo: collection, CollectionLoadInfo: collection,
} }
@ -104,95 +107,172 @@ func (m *CollectionManager) Recover() error {
m.store.ReleasePartition(collection, partitionIDs...) m.store.ReleasePartition(collection, partitionIDs...)
break break
} }
m.partitions[partition.PartitionID] = &Partition{ m.partitions[partition.PartitionID] = &Partition{
PartitionLoadInfo: partition, PartitionLoadInfo: partition,
} }
} }
} }
err = m.upgradeRecover(broker)
if err != nil {
log.Error("upgrade recover failed", zap.Error(err))
return err
}
return nil return nil
} }
func (m *CollectionManager) GetCollection(id UniqueID) *Collection { // upgradeRecover recovers from old version <= 2.2.x for compatibility.
m.rwmutex.RLock() func (m *CollectionManager) upgradeRecover(broker Broker) error {
defer m.rwmutex.RUnlock() for _, collection := range m.GetAllCollections() {
// It's a workaround to check if it is old CollectionLoadInfo because there's no
return m.collections[id] // loadType in old version, maybe we should use version instead.
if collection.GetLoadType() == querypb.LoadType_UnKnownType {
partitionIDs, err := broker.GetPartitions(context.Background(), collection.GetCollectionID())
if err != nil {
return err
}
partitions := lo.Map(partitionIDs, func(partitionID int64, _ int) *Partition {
return &Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: collection.GetCollectionID(),
PartitionID: partitionID,
Status: querypb.LoadStatus_Loaded,
},
LoadPercentage: 100,
}
})
err = m.putPartition(partitions, true)
if err != nil {
return err
}
}
}
for _, partition := range m.GetAllPartitions() {
// In old versions, the collection itself was NOT stored when only partitions were loaded.
if _, ok := m.collections[partition.GetCollectionID()]; !ok {
col := &Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: partition.GetCollectionID(),
ReplicaNumber: partition.GetReplicaNumber(),
Status: partition.GetStatus(),
FieldIndexID: partition.GetFieldIndexID(),
LoadType: querypb.LoadType_LoadPartition,
},
LoadPercentage: 100,
}
err := m.PutCollection(col)
if err != nil {
return err
}
}
}
return nil
} }
func (m *CollectionManager) GetPartition(id UniqueID) *Partition { func (m *CollectionManager) GetCollection(collectionID UniqueID) *Collection {
m.rwmutex.RLock() m.rwmutex.RLock()
defer m.rwmutex.RUnlock() defer m.rwmutex.RUnlock()
return m.partitions[id] return m.collections[collectionID]
} }
func (m *CollectionManager) GetLoadType(id UniqueID) querypb.LoadType { func (m *CollectionManager) GetPartition(partitionID UniqueID) *Partition {
m.rwmutex.RLock() m.rwmutex.RLock()
defer m.rwmutex.RUnlock() defer m.rwmutex.RUnlock()
_, ok := m.collections[id] return m.partitions[partitionID]
}
func (m *CollectionManager) GetLoadType(collectionID UniqueID) querypb.LoadType {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
collection, ok := m.collections[collectionID]
if ok { if ok {
return querypb.LoadType_LoadCollection return collection.GetLoadType()
}
if len(m.getPartitionsByCollection(id)) > 0 {
return querypb.LoadType_LoadPartition
} }
return querypb.LoadType_UnKnownType return querypb.LoadType_UnKnownType
} }
func (m *CollectionManager) GetReplicaNumber(id UniqueID) int32 { func (m *CollectionManager) GetReplicaNumber(collectionID UniqueID) int32 {
m.rwmutex.RLock() m.rwmutex.RLock()
defer m.rwmutex.RUnlock() defer m.rwmutex.RUnlock()
collection, ok := m.collections[id] collection, ok := m.collections[collectionID]
if ok { if ok {
return collection.GetReplicaNumber() return collection.GetReplicaNumber()
} }
partitions := m.getPartitionsByCollection(id)
if len(partitions) > 0 {
return partitions[0].GetReplicaNumber()
}
return -1 return -1
} }
func (m *CollectionManager) GetLoadPercentage(id UniqueID) int32 { // GetCurrentLoadPercentage checks if collection is currently fully loaded.
func (m *CollectionManager) GetCurrentLoadPercentage(collectionID UniqueID) int32 {
m.rwmutex.RLock() m.rwmutex.RLock()
defer m.rwmutex.RUnlock() defer m.rwmutex.RUnlock()
collection, ok := m.collections[id] collection, ok := m.collections[collectionID]
if ok { if ok {
return collection.LoadPercentage partitions := m.getPartitionsByCollection(collectionID)
}
partitions := m.getPartitionsByCollection(id)
if len(partitions) > 0 { if len(partitions) > 0 {
return lo.SumBy(partitions, func(partition *Partition) int32 { return lo.SumBy(partitions, func(partition *Partition) int32 {
return partition.LoadPercentage return partition.LoadPercentage
}) / int32(len(partitions)) }) / int32(len(partitions))
} }
if collection.GetLoadType() == querypb.LoadType_LoadCollection {
// no partition exists
return 100
}
}
return -1 return -1
} }
func (m *CollectionManager) GetStatus(id UniqueID) querypb.LoadStatus { // GetCollectionLoadPercentage returns collection load percentage.
// Note: collection.LoadPercentage == 100 only means the collection was fully loaded at some point and is queryable;
// to check whether it is fully loaded right now, use GetCurrentLoadPercentage instead.
func (m *CollectionManager) GetCollectionLoadPercentage(collectionID UniqueID) int32 {
m.rwmutex.RLock() m.rwmutex.RLock()
defer m.rwmutex.RUnlock() defer m.rwmutex.RUnlock()
collection, ok := m.collections[id] collection, ok := m.collections[collectionID]
if ok { if ok {
return collection.GetStatus() return collection.LoadPercentage
} }
partitions := m.getPartitionsByCollection(id) return -1
if len(partitions) == 0 { }
func (m *CollectionManager) GetPartitionLoadPercentage(partitionID UniqueID) int32 {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
partition, ok := m.partitions[partitionID]
if ok {
return partition.LoadPercentage
}
return -1
}
func (m *CollectionManager) GetStatus(collectionID UniqueID) querypb.LoadStatus {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
collection, ok := m.collections[collectionID]
if !ok {
return querypb.LoadStatus_Invalid return querypb.LoadStatus_Invalid
} }
partitions := m.getPartitionsByCollection(collectionID)
for _, partition := range partitions { for _, partition := range partitions {
if partition.GetStatus() == querypb.LoadStatus_Loading { if partition.GetStatus() == querypb.LoadStatus_Loading {
return querypb.LoadStatus_Loading return querypb.LoadStatus_Loading
} }
} }
if len(partitions) > 0 {
return querypb.LoadStatus_Loaded return querypb.LoadStatus_Loaded
} }
if collection.GetLoadType() == querypb.LoadType_LoadCollection {
return querypb.LoadStatus_Loaded
}
return querypb.LoadStatus_Invalid
}
func (m *CollectionManager) GetFieldIndex(collectionID UniqueID) map[int64]int64 { func (m *CollectionManager) GetFieldIndex(collectionID UniqueID) map[int64]int64 {
m.rwmutex.RLock() m.rwmutex.RLock()
@ -202,12 +282,8 @@ func (m *CollectionManager) GetFieldIndex(collectionID UniqueID) map[int64]int64
if ok { if ok {
return collection.GetFieldIndexID() return collection.GetFieldIndexID()
} }
partitions := m.getPartitionsByCollection(collectionID)
if len(partitions) == 0 {
return nil return nil
} }
return partitions[0].GetFieldIndexID()
}
// ContainAnyIndex returns true if the loaded collection contains one of the given indexes, // ContainAnyIndex returns true if the loaded collection contains one of the given indexes,
// returns false otherwise. // returns false otherwise.
@ -228,31 +304,18 @@ func (m *CollectionManager) containIndex(collectionID, indexID int64) bool {
if ok { if ok {
return lo.Contains(lo.Values(collection.GetFieldIndexID()), indexID) return lo.Contains(lo.Values(collection.GetFieldIndexID()), indexID)
} }
partitions := m.getPartitionsByCollection(collectionID)
if len(partitions) == 0 {
return false
}
for _, partition := range partitions {
if lo.Contains(lo.Values(partition.GetFieldIndexID()), indexID) {
return true
}
}
return false return false
} }
func (m *CollectionManager) Exist(id UniqueID) bool { func (m *CollectionManager) Exist(collectionID UniqueID) bool {
m.rwmutex.RLock() m.rwmutex.RLock()
defer m.rwmutex.RUnlock() defer m.rwmutex.RUnlock()
_, ok := m.collections[id] _, ok := m.collections[collectionID]
if ok { return ok
return true
}
partitions := m.getPartitionsByCollection(id)
return len(partitions) > 0
} }
// GetAll returns the collection ID of all loaded collections and partitions // GetAll returns the collection ID of all loaded collections
func (m *CollectionManager) GetAll() []int64 { func (m *CollectionManager) GetAll() []int64 {
m.rwmutex.RLock() m.rwmutex.RLock()
defer m.rwmutex.RUnlock() defer m.rwmutex.RUnlock()
@ -261,9 +324,6 @@ func (m *CollectionManager) GetAll() []int64 {
for _, collection := range m.collections { for _, collection := range m.collections {
ids.Insert(collection.GetCollectionID()) ids.Insert(collection.GetCollectionID())
} }
for _, partition := range m.partitions {
ids.Insert(partition.GetCollectionID())
}
return ids.Collect() return ids.Collect()
} }
@ -298,11 +358,11 @@ func (m *CollectionManager) getPartitionsByCollection(collectionID UniqueID) []*
return partitions return partitions
} }
func (m *CollectionManager) PutCollection(collection *Collection) error { func (m *CollectionManager) PutCollection(collection *Collection, partitions ...*Partition) error {
m.rwmutex.Lock() m.rwmutex.Lock()
defer m.rwmutex.Unlock() defer m.rwmutex.Unlock()
return m.putCollection(collection, true) return m.putCollection(true, collection, partitions...)
} }
func (m *CollectionManager) UpdateCollection(collection *Collection) error { func (m *CollectionManager) UpdateCollection(collection *Collection) error {
@ -314,7 +374,7 @@ func (m *CollectionManager) UpdateCollection(collection *Collection) error {
return merr.WrapErrCollectionNotFound(collection.GetCollectionID()) return merr.WrapErrCollectionNotFound(collection.GetCollectionID())
} }
return m.putCollection(collection, true) return m.putCollection(true, collection)
} }
func (m *CollectionManager) UpdateCollectionInMemory(collection *Collection) bool { func (m *CollectionManager) UpdateCollectionInMemory(collection *Collection) bool {
@ -326,17 +386,24 @@ func (m *CollectionManager) UpdateCollectionInMemory(collection *Collection) boo
return false return false
} }
m.putCollection(collection, false) m.putCollection(false, collection)
return true return true
} }
func (m *CollectionManager) putCollection(collection *Collection, withSave bool) error { func (m *CollectionManager) putCollection(withSave bool, collection *Collection, partitions ...*Partition) error {
if withSave { if withSave {
err := m.store.SaveCollection(collection.CollectionLoadInfo) partitionInfos := lo.Map(partitions, func(partition *Partition, _ int) *querypb.PartitionLoadInfo {
return partition.PartitionLoadInfo
})
err := m.store.SaveCollection(collection.CollectionLoadInfo, partitionInfos...)
if err != nil { if err != nil {
return err return err
} }
} }
for _, partition := range partitions {
partition.UpdatedAt = time.Now()
m.partitions[partition.GetPartitionID()] = partition
}
collection.UpdatedAt = time.Now() collection.UpdatedAt = time.Now()
m.collections[collection.CollectionID] = collection m.collections[collection.CollectionID] = collection
@ -399,25 +466,25 @@ func (m *CollectionManager) putPartition(partitions []*Partition, withSave bool)
return nil return nil
} }
func (m *CollectionManager) RemoveCollection(id UniqueID) error { // RemoveCollection removes collection and its partitions.
func (m *CollectionManager) RemoveCollection(collectionID UniqueID) error {
m.rwmutex.Lock() m.rwmutex.Lock()
defer m.rwmutex.Unlock() defer m.rwmutex.Unlock()
_, ok := m.collections[id] _, ok := m.collections[collectionID]
if ok { if ok {
err := m.store.ReleaseCollection(id) err := m.store.ReleaseCollection(collectionID)
if err != nil { if err != nil {
return err return err
} }
delete(m.collections, id) delete(m.collections, collectionID)
return nil for partID, partition := range m.partitions {
if partition.CollectionID == collectionID {
delete(m.partitions, partID)
} }
}
partitions := lo.Map(m.getPartitionsByCollection(id), }
func(partition *Partition, _ int) int64 { return nil
return partition.GetPartitionID()
})
return m.removePartition(partitions...)
} }
func (m *CollectionManager) RemovePartition(ids ...UniqueID) error { func (m *CollectionManager) RemovePartition(ids ...UniqueID) error {

View File

@ -21,12 +21,15 @@ import (
"testing" "testing"
"time" "time"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/internal/kv" "github.com/milvus-io/milvus/internal/kv"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
. "github.com/milvus-io/milvus/internal/querycoordv2/params" . "github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/etcd"
"github.com/stretchr/testify/suite"
) )
type CollectionManagerSuite struct { type CollectionManagerSuite struct {
@ -37,11 +40,13 @@ type CollectionManagerSuite struct {
partitions map[int64][]int64 // CollectionID -> PartitionIDs partitions map[int64][]int64 // CollectionID -> PartitionIDs
loadTypes []querypb.LoadType loadTypes []querypb.LoadType
replicaNumber []int32 replicaNumber []int32
loadPercentage []int32 colLoadPercent []int32
parLoadPercent map[int64][]int32
// Mocks // Mocks
kv kv.MetaKv kv kv.MetaKv
store Store store Store
broker *MockBroker
// Test object // Test object
mgr *CollectionManager mgr *CollectionManager
@ -50,19 +55,27 @@ type CollectionManagerSuite struct {
func (suite *CollectionManagerSuite) SetupSuite() { func (suite *CollectionManagerSuite) SetupSuite() {
Params.Init() Params.Init()
suite.collections = []int64{100, 101, 102} suite.collections = []int64{100, 101, 102, 103}
suite.partitions = map[int64][]int64{ suite.partitions = map[int64][]int64{
100: {10}, 100: {10},
101: {11, 12}, 101: {11, 12},
102: {13, 14, 15}, 102: {13, 14, 15},
103: {}, // no partitions in this collection
} }
suite.loadTypes = []querypb.LoadType{ suite.loadTypes = []querypb.LoadType{
querypb.LoadType_LoadCollection, querypb.LoadType_LoadCollection,
querypb.LoadType_LoadPartition, querypb.LoadType_LoadPartition,
querypb.LoadType_LoadCollection, querypb.LoadType_LoadCollection,
querypb.LoadType_LoadCollection,
}
suite.replicaNumber = []int32{1, 2, 3, 1}
suite.colLoadPercent = []int32{0, 50, 100, 100}
suite.parLoadPercent = map[int64][]int32{
100: {0},
101: {0, 100},
102: {100, 100, 100},
103: {},
} }
suite.replicaNumber = []int32{1, 2, 3}
suite.loadPercentage = []int32{0, 50, 100}
} }
func (suite *CollectionManagerSuite) SetupTest() { func (suite *CollectionManagerSuite) SetupTest() {
@ -79,6 +92,7 @@ func (suite *CollectionManagerSuite) SetupTest() {
suite.Require().NoError(err) suite.Require().NoError(err)
suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue()) suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue())
suite.store = NewMetaStore(suite.kv) suite.store = NewMetaStore(suite.kv)
suite.broker = NewMockBroker(suite.T())
suite.mgr = NewCollectionManager(suite.store) suite.mgr = NewCollectionManager(suite.store)
suite.loadAll() suite.loadAll()
@ -94,18 +108,18 @@ func (suite *CollectionManagerSuite) TestGetProperty() {
for i, collection := range suite.collections { for i, collection := range suite.collections {
loadType := mgr.GetLoadType(collection) loadType := mgr.GetLoadType(collection)
replicaNumber := mgr.GetReplicaNumber(collection) replicaNumber := mgr.GetReplicaNumber(collection)
percentage := mgr.GetLoadPercentage(collection) percentage := mgr.GetCurrentLoadPercentage(collection)
exist := mgr.Exist(collection) exist := mgr.Exist(collection)
suite.Equal(suite.loadTypes[i], loadType) suite.Equal(suite.loadTypes[i], loadType)
suite.Equal(suite.replicaNumber[i], replicaNumber) suite.Equal(suite.replicaNumber[i], replicaNumber)
suite.Equal(suite.loadPercentage[i], percentage) suite.Equal(suite.colLoadPercent[i], percentage)
suite.True(exist) suite.True(exist)
} }
invalidCollection := -1 invalidCollection := -1
loadType := mgr.GetLoadType(int64(invalidCollection)) loadType := mgr.GetLoadType(int64(invalidCollection))
replicaNumber := mgr.GetReplicaNumber(int64(invalidCollection)) replicaNumber := mgr.GetReplicaNumber(int64(invalidCollection))
percentage := mgr.GetLoadPercentage(int64(invalidCollection)) percentage := mgr.GetCurrentLoadPercentage(int64(invalidCollection))
exist := mgr.Exist(int64(invalidCollection)) exist := mgr.Exist(int64(invalidCollection))
suite.Equal(querypb.LoadType_UnKnownType, loadType) suite.Equal(querypb.LoadType_UnKnownType, loadType)
suite.EqualValues(-1, replicaNumber) suite.EqualValues(-1, replicaNumber)
@ -113,33 +127,45 @@ func (suite *CollectionManagerSuite) TestGetProperty() {
suite.False(exist) suite.False(exist)
} }
func (suite *CollectionManagerSuite) TestPut() {
suite.releaseAll()
// test put collection with partitions
for i, collection := range suite.collections {
status := querypb.LoadStatus_Loaded
if suite.colLoadPercent[i] < 100 {
status = querypb.LoadStatus_Loading
}
col := &Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: collection,
ReplicaNumber: suite.replicaNumber[i],
Status: status,
LoadType: suite.loadTypes[i],
},
LoadPercentage: suite.colLoadPercent[i],
CreatedAt: time.Now(),
}
partitions := lo.Map(suite.partitions[collection], func(partition int64, j int) *Partition {
return &Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: collection,
PartitionID: partition,
ReplicaNumber: suite.replicaNumber[i],
Status: status,
},
LoadPercentage: suite.parLoadPercent[collection][j],
CreatedAt: time.Now(),
}
})
err := suite.mgr.PutCollection(col, partitions...)
suite.NoError(err)
}
suite.checkLoadResult()
}
func (suite *CollectionManagerSuite) TestGet() { func (suite *CollectionManagerSuite) TestGet() {
mgr := suite.mgr suite.checkLoadResult()
allCollections := mgr.GetAllCollections()
allPartitions := mgr.GetAllPartitions()
for i, collectionID := range suite.collections {
if suite.loadTypes[i] == querypb.LoadType_LoadCollection {
collection := mgr.GetCollection(collectionID)
suite.Equal(collectionID, collection.GetCollectionID())
suite.Contains(allCollections, collection)
} else {
partitions := mgr.GetPartitionsByCollection(collectionID)
suite.Len(partitions, len(suite.partitions[collectionID]))
for _, partitionID := range suite.partitions[collectionID] {
partition := mgr.GetPartition(partitionID)
suite.Equal(collectionID, partition.GetCollectionID())
suite.Equal(partitionID, partition.GetPartitionID())
suite.Contains(partitions, partition)
suite.Contains(allPartitions, partition)
}
}
}
all := mgr.GetAll()
sort.Slice(all, func(i, j int) bool { return all[i] < all[j] })
suite.Equal(suite.collections, all)
} }
func (suite *CollectionManagerSuite) TestUpdate() { func (suite *CollectionManagerSuite) TestUpdate() {
@ -177,7 +203,7 @@ func (suite *CollectionManagerSuite) TestUpdate() {
} }
suite.clearMemory() suite.clearMemory()
err := mgr.Recover() err := mgr.Recover(suite.broker)
suite.NoError(err) suite.NoError(err)
collections = mgr.GetAllCollections() collections = mgr.GetAllCollections()
partitions = mgr.GetAllPartitions() partitions = mgr.GetAllPartitions()
@ -215,7 +241,7 @@ func (suite *CollectionManagerSuite) TestRemove() {
} }
// Make sure the removes applied to meta store // Make sure the removes applied to meta store
err := mgr.Recover() err := mgr.Recover(suite.broker)
suite.NoError(err) suite.NoError(err)
for i, collectionID := range suite.collections { for i, collectionID := range suite.collections {
if suite.loadTypes[i] == querypb.LoadType_LoadCollection { if suite.loadTypes[i] == querypb.LoadType_LoadCollection {
@ -237,37 +263,50 @@ func (suite *CollectionManagerSuite) TestRemove() {
suite.Empty(partitions) suite.Empty(partitions)
} }
} }
// removing a collection should also release its partitions
suite.releaseAll()
suite.loadAll()
for _, collectionID := range suite.collections {
err := mgr.RemoveCollection(collectionID)
suite.NoError(err)
err = mgr.Recover(suite.broker)
suite.NoError(err)
collection := mgr.GetCollection(collectionID)
suite.Nil(collection)
partitions := mgr.GetPartitionsByCollection(collectionID)
suite.Empty(partitions)
}
} }
func (suite *CollectionManagerSuite) TestRecover() { func (suite *CollectionManagerSuite) TestRecover() {
mgr := suite.mgr mgr := suite.mgr
suite.clearMemory() suite.clearMemory()
err := mgr.Recover() err := mgr.Recover(suite.broker)
suite.NoError(err) suite.NoError(err)
for i, collection := range suite.collections { for i, collection := range suite.collections {
exist := suite.loadPercentage[i] == 100 exist := suite.colLoadPercent[i] == 100
suite.Equal(exist, mgr.Exist(collection)) suite.Equal(exist, mgr.Exist(collection))
} }
} }
func (suite *CollectionManagerSuite) loadAll() { func (suite *CollectionManagerSuite) TestUpgradeRecover() {
suite.releaseAll()
mgr := suite.mgr mgr := suite.mgr
// put collections and partitions in the old version's format
for i, collection := range suite.collections { for i, collection := range suite.collections {
status := querypb.LoadStatus_Loaded status := querypb.LoadStatus_Loaded
if suite.loadPercentage[i] < 100 {
status = querypb.LoadStatus_Loading
}
if suite.loadTypes[i] == querypb.LoadType_LoadCollection { if suite.loadTypes[i] == querypb.LoadType_LoadCollection {
mgr.PutCollection(&Collection{ mgr.PutCollection(&Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: collection, CollectionID: collection,
ReplicaNumber: suite.replicaNumber[i], ReplicaNumber: suite.replicaNumber[i],
Status: status, Status: status,
LoadType: querypb.LoadType_UnKnownType, // old-version collections didn't set loadType
}, },
LoadPercentage: suite.loadPercentage[i], LoadPercentage: suite.colLoadPercent[i],
CreatedAt: time.Now(), CreatedAt: time.Now(),
}) })
} else { } else {
@ -279,12 +318,92 @@ func (suite *CollectionManagerSuite) loadAll() {
ReplicaNumber: suite.replicaNumber[i], ReplicaNumber: suite.replicaNumber[i],
Status: status, Status: status,
}, },
LoadPercentage: suite.loadPercentage[i], LoadPercentage: suite.colLoadPercent[i],
CreatedAt: time.Now(), CreatedAt: time.Now(),
}) })
} }
} }
} }
// set expectations
for i, collection := range suite.collections {
if suite.loadTypes[i] == querypb.LoadType_LoadCollection {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil)
}
}
// do recovery
suite.clearMemory()
err := mgr.Recover(suite.broker)
suite.NoError(err)
suite.checkLoadResult()
}
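For orientation, a minimal standalone sketch of the upgrade-recovery idea exercised above: collections persisted by an older version carry no partition-level meta, so recovery has to ask the broker for the partition list and backfill partition entries. All names below (collectionInfo, partitionLister, backfillPartitions, fakeBroker) are illustrative placeholders, not the real CollectionManager.Recover implementation.

package main

import (
	"context"
	"fmt"
)

// collectionInfo is an illustrative stand-in for a recovered load-info record.
type collectionInfo struct {
	ID         int64
	LoadType   string  // "" marks an old-version record with no load type set
	Partitions []int64 // empty for old-version records
}

// partitionLister mimics the broker's GetPartitions call.
type partitionLister interface {
	GetPartitions(ctx context.Context, collectionID int64) ([]int64, error)
}

// backfillPartitions fetches partition IDs for old-version records so that,
// after recovery, every loaded collection also has partition-level meta.
func backfillPartitions(ctx context.Context, broker partitionLister, cols []*collectionInfo) error {
	for _, col := range cols {
		if col.LoadType != "" || len(col.Partitions) > 0 {
			continue // already in the new format
		}
		partitions, err := broker.GetPartitions(ctx, col.ID)
		if err != nil {
			return err
		}
		col.Partitions = partitions
		col.LoadType = "LoadCollection"
	}
	return nil
}

type fakeBroker map[int64][]int64

func (b fakeBroker) GetPartitions(_ context.Context, id int64) ([]int64, error) {
	return b[id], nil
}

func main() {
	cols := []*collectionInfo{{ID: 100}} // old-version record
	if err := backfillPartitions(context.Background(), fakeBroker{100: {10, 11}}, cols); err != nil {
		panic(err)
	}
	fmt.Println(cols[0].Partitions) // [10 11]
}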
func (suite *CollectionManagerSuite) loadAll() {
mgr := suite.mgr
for i, collection := range suite.collections {
status := querypb.LoadStatus_Loaded
if suite.colLoadPercent[i] < 100 {
status = querypb.LoadStatus_Loading
}
mgr.PutCollection(&Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: collection,
ReplicaNumber: suite.replicaNumber[i],
Status: status,
LoadType: suite.loadTypes[i],
},
LoadPercentage: suite.colLoadPercent[i],
CreatedAt: time.Now(),
})
for j, partition := range suite.partitions[collection] {
mgr.PutPartition(&Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: collection,
PartitionID: partition,
Status: status,
},
LoadPercentage: suite.parLoadPercent[collection][j],
CreatedAt: time.Now(),
})
}
}
}
func (suite *CollectionManagerSuite) checkLoadResult() {
mgr := suite.mgr
allCollections := mgr.GetAllCollections()
allPartitions := mgr.GetAllPartitions()
for _, collectionID := range suite.collections {
collection := mgr.GetCollection(collectionID)
suite.Equal(collectionID, collection.GetCollectionID())
suite.Contains(allCollections, collection)
partitions := mgr.GetPartitionsByCollection(collectionID)
suite.Len(partitions, len(suite.partitions[collectionID]))
for _, partitionID := range suite.partitions[collectionID] {
partition := mgr.GetPartition(partitionID)
suite.Equal(collectionID, partition.GetCollectionID())
suite.Equal(partitionID, partition.GetPartitionID())
suite.Contains(partitions, partition)
suite.Contains(allPartitions, partition)
}
}
all := mgr.GetAll()
sort.Slice(all, func(i, j int) bool { return all[i] < all[j] })
suite.Equal(suite.collections, all)
}
func (suite *CollectionManagerSuite) releaseAll() {
for _, collection := range suite.collections {
err := suite.mgr.RemoveCollection(collection)
suite.NoError(err)
}
} }
func (suite *CollectionManagerSuite) clearMemory() { func (suite *CollectionManagerSuite) clearMemory() {
View File
@ -1,4 +1,4 @@
// Code generated by mockery v2.16.0. DO NOT EDIT. // Code generated by mockery v2.14.0. DO NOT EDIT.
package meta package meta
@ -200,13 +200,13 @@ func (_c *MockStore_GetResourceGroups_Call) Return(_a0 []*querypb.ResourceGroup,
return _c return _c
} }
// ReleaseCollection provides a mock function with given fields: id // ReleaseCollection provides a mock function with given fields: collection
func (_m *MockStore) ReleaseCollection(id int64) error { func (_m *MockStore) ReleaseCollection(collection int64) error {
ret := _m.Called(id) ret := _m.Called(collection)
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(int64) error); ok { if rf, ok := ret.Get(0).(func(int64) error); ok {
r0 = rf(id) r0 = rf(collection)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)
} }
@ -220,12 +220,12 @@ type MockStore_ReleaseCollection_Call struct {
} }
// ReleaseCollection is a helper method to define mock.On call // ReleaseCollection is a helper method to define mock.On call
// - id int64 // - collection int64
func (_e *MockStore_Expecter) ReleaseCollection(id interface{}) *MockStore_ReleaseCollection_Call { func (_e *MockStore_Expecter) ReleaseCollection(collection interface{}) *MockStore_ReleaseCollection_Call {
return &MockStore_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", id)} return &MockStore_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", collection)}
} }
func (_c *MockStore_ReleaseCollection_Call) Run(run func(id int64)) *MockStore_ReleaseCollection_Call { func (_c *MockStore_ReleaseCollection_Call) Run(run func(collection int64)) *MockStore_ReleaseCollection_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64)) run(args[0].(int64))
}) })
@ -401,13 +401,20 @@ func (_c *MockStore_RemoveResourceGroup_Call) Return(_a0 error) *MockStore_Remov
return _c return _c
} }
// SaveCollection provides a mock function with given fields: info // SaveCollection provides a mock function with given fields: collection, partitions
func (_m *MockStore) SaveCollection(info *querypb.CollectionLoadInfo) error { func (_m *MockStore) SaveCollection(collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo) error {
ret := _m.Called(info) _va := make([]interface{}, len(partitions))
for _i := range partitions {
_va[_i] = partitions[_i]
}
var _ca []interface{}
_ca = append(_ca, collection)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(*querypb.CollectionLoadInfo) error); ok { if rf, ok := ret.Get(0).(func(*querypb.CollectionLoadInfo, ...*querypb.PartitionLoadInfo) error); ok {
r0 = rf(info) r0 = rf(collection, partitions...)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)
} }
@ -421,14 +428,22 @@ type MockStore_SaveCollection_Call struct {
} }
// SaveCollection is a helper method to define mock.On call // SaveCollection is a helper method to define mock.On call
// - info *querypb.CollectionLoadInfo // - collection *querypb.CollectionLoadInfo
func (_e *MockStore_Expecter) SaveCollection(info interface{}) *MockStore_SaveCollection_Call { // - partitions ...*querypb.PartitionLoadInfo
return &MockStore_SaveCollection_Call{Call: _e.mock.On("SaveCollection", info)} func (_e *MockStore_Expecter) SaveCollection(collection interface{}, partitions ...interface{}) *MockStore_SaveCollection_Call {
return &MockStore_SaveCollection_Call{Call: _e.mock.On("SaveCollection",
append([]interface{}{collection}, partitions...)...)}
} }
func (_c *MockStore_SaveCollection_Call) Run(run func(info *querypb.CollectionLoadInfo)) *MockStore_SaveCollection_Call { func (_c *MockStore_SaveCollection_Call) Run(run func(collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo)) *MockStore_SaveCollection_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
run(args[0].(*querypb.CollectionLoadInfo)) variadicArgs := make([]*querypb.PartitionLoadInfo, len(args)-1)
for i, a := range args[1:] {
if a != nil {
variadicArgs[i] = a.(*querypb.PartitionLoadInfo)
}
}
run(args[0].(*querypb.CollectionLoadInfo), variadicArgs...)
}) })
return _c return _c
} }
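Because the regenerated mock flattens the variadic partitions into the recorded call arguments, an expectation has to list each partition value (or a matcher per element). A self-contained illustration of that pattern with a hypothetical fakeStore, assuming github.com/stretchr/testify is available; it is not taken from the Milvus test suite.

package main

import (
	"fmt"

	"github.com/stretchr/testify/mock"
)

// fakeStore mirrors the shape of the regenerated SaveCollection mock:
// variadic arguments are expanded before being recorded on the call.
type fakeStore struct {
	mock.Mock
}

func (s *fakeStore) SaveCollection(collection int64, partitions ...int64) error {
	args := make([]interface{}, 0, len(partitions)+1)
	args = append(args, collection)
	for _, p := range partitions {
		args = append(args, p)
	}
	// Each variadic element becomes a separate recorded argument.
	return s.Called(args...).Error(0)
}

func main() {
	s := new(fakeStore)
	// The expectation therefore enumerates the partitions one by one.
	s.On("SaveCollection", int64(2), int64(102), int64(103)).Return(nil)
	fmt.Println(s.SaveCollection(2, 102, 103)) // <nil>
}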
View File
@ -61,13 +61,23 @@ func NewMetaStore(cli kv.MetaKv) metaStore {
} }
} }
func (s metaStore) SaveCollection(info *querypb.CollectionLoadInfo) error { func (s metaStore) SaveCollection(collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo) error {
k := encodeCollectionLoadInfoKey(info.GetCollectionID()) k := encodeCollectionLoadInfoKey(collection.GetCollectionID())
v, err := proto.Marshal(info) v, err := proto.Marshal(collection)
if err != nil { if err != nil {
return err return err
} }
return s.cli.Save(k, string(v)) kvs := make(map[string]string)
for _, partition := range partitions {
key := encodePartitionLoadInfoKey(partition.GetCollectionID(), partition.GetPartitionID())
value, err := proto.Marshal(partition)
if err != nil {
return err
}
kvs[key] = string(value)
}
kvs[k] = string(v)
return s.cli.MultiSave(kvs)
} }
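SaveCollection now batches the collection record and all of its partition records into one MultiSave, so a load request is persisted completely or not at all. A minimal sketch of that batching against a generic save callback; the key formats and placeholder values are illustrative, the real store proto-marshals CollectionLoadInfo/PartitionLoadInfo under its own key encoding.

package main

import "fmt"

// saveCollection bundles one collection entry and its partition entries into
// a single batched write, mimicking metaStore.SaveCollection above.
func saveCollection(multiSave func(map[string]string) error, collection int64, partitions ...int64) error {
	kvs := make(map[string]string, len(partitions)+1)
	kvs[fmt.Sprintf("collection-loadinfo/%d", collection)] = "<marshaled CollectionLoadInfo>"
	for _, p := range partitions {
		kvs[fmt.Sprintf("partition-loadinfo/%d/%d", collection, p)] = "<marshaled PartitionLoadInfo>"
	}
	// One batched call instead of N+1 individual saves.
	return multiSave(kvs)
}

func main() {
	err := saveCollection(func(kvs map[string]string) error {
		fmt.Println("saving", len(kvs), "keys in one batch") // saving 3 keys in one batch
		return nil
	}, 2, 102, 103)
	fmt.Println(err) // <nil>
}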
func (s metaStore) SavePartition(info ...*querypb.PartitionLoadInfo) error { func (s metaStore) SavePartition(info ...*querypb.PartitionLoadInfo) error {
@ -211,9 +221,27 @@ func (s metaStore) GetResourceGroups() ([]*querypb.ResourceGroup, error) {
return ret, nil return ret, nil
} }
func (s metaStore) ReleaseCollection(id int64) error { func (s metaStore) ReleaseCollection(collection int64) error {
k := encodeCollectionLoadInfoKey(id) // obtain partitions of this collection
return s.cli.Remove(k) _, values, err := s.cli.LoadWithPrefix(fmt.Sprintf("%s/%d", PartitionLoadInfoPrefix, collection))
if err != nil {
return err
}
partitions := make([]*querypb.PartitionLoadInfo, 0)
for _, v := range values {
info := querypb.PartitionLoadInfo{}
if err = proto.Unmarshal([]byte(v), &info); err != nil {
return err
}
partitions = append(partitions, &info)
}
// remove the collection and the partitions obtained above
keys := lo.Map(partitions, func(partition *querypb.PartitionLoadInfo, _ int) string {
return encodePartitionLoadInfoKey(collection, partition.GetPartitionID())
})
k := encodeCollectionLoadInfoKey(collection)
keys = append(keys, k)
return s.cli.MultiRemove(keys)
} }
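ReleaseCollection relies on partition load-info keys sharing a per-collection prefix, so one prefix scan finds every partition entry and one MultiRemove deletes them together with the collection entry. A sketch of that layout against a plain in-memory map; the key format here is illustrative, not the exact output of encodeCollectionLoadInfoKey/encodePartitionLoadInfoKey.

package main

import (
	"fmt"
	"strings"
)

// Illustrative key builders; the real encoders differ in detail.
func collectionKey(collection int64) string {
	return fmt.Sprintf("collection-loadinfo/%d", collection)
}

func partitionKey(collection, partition int64) string {
	return fmt.Sprintf("partition-loadinfo/%d/%d", collection, partition)
}

// releaseCollection removes the collection entry and every partition entry
// under the collection's prefix, mimicking LoadWithPrefix + MultiRemove.
func releaseCollection(store map[string]string, collection int64) {
	prefix := fmt.Sprintf("partition-loadinfo/%d/", collection)
	for k := range store {
		if strings.HasPrefix(k, prefix) {
			delete(store, k)
		}
	}
	delete(store, collectionKey(collection))
}

func main() {
	store := map[string]string{
		collectionKey(2):     "col-2",
		partitionKey(2, 102): "par-102",
		collectionKey(3):     "col-3",
		partitionKey(3, 103): "par-103",
	}
	releaseCollection(store, 2)
	fmt.Println(len(store)) // 2: collection 3 and its partition are untouched
}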
func (s metaStore) ReleasePartition(collection int64, partitions ...int64) error { func (s metaStore) ReleasePartition(collection int64, partitions ...int64) error {
View File
@ -81,6 +81,39 @@ func (suite *StoreTestSuite) TestCollection() {
suite.Len(collections, 1) suite.Len(collections, 1)
} }
func (suite *StoreTestSuite) TestCollectionWithPartition() {
suite.store.SaveCollection(&querypb.CollectionLoadInfo{
CollectionID: 1,
})
suite.store.SaveCollection(&querypb.CollectionLoadInfo{
CollectionID: 2,
}, &querypb.PartitionLoadInfo{
CollectionID: 2,
PartitionID: 102,
})
suite.store.SaveCollection(&querypb.CollectionLoadInfo{
CollectionID: 3,
}, &querypb.PartitionLoadInfo{
CollectionID: 3,
PartitionID: 103,
})
suite.store.ReleaseCollection(1)
suite.store.ReleaseCollection(2)
collections, err := suite.store.GetCollections()
suite.NoError(err)
suite.Len(collections, 1)
suite.Equal(int64(3), collections[0].GetCollectionID())
partitions, err := suite.store.GetPartitions()
suite.NoError(err)
suite.Len(partitions, 1)
suite.Len(partitions[int64(3)], 1)
suite.Equal(int64(103), partitions[int64(3)][0].GetPartitionID())
}
func (suite *StoreTestSuite) TestPartition() { func (suite *StoreTestSuite) TestPartition() {
suite.store.SavePartition(&querypb.PartitionLoadInfo{ suite.store.SavePartition(&querypb.PartitionLoadInfo{
PartitionID: 1, PartitionID: 1,
View File
@ -106,22 +106,10 @@ func (mgr *TargetManager) UpdateCollectionNextTarget(collectionID int64) error {
mgr.rwMutex.Lock() mgr.rwMutex.Lock()
defer mgr.rwMutex.Unlock() defer mgr.rwMutex.Unlock()
partitionIDs := make([]int64, 0)
collection := mgr.meta.GetCollection(collectionID)
if collection != nil {
var err error
partitionIDs, err = mgr.broker.GetPartitions(context.Background(), collectionID)
if err != nil {
return err
}
} else {
partitions := mgr.meta.GetPartitionsByCollection(collectionID) partitions := mgr.meta.GetPartitionsByCollection(collectionID)
if partitions != nil { partitionIDs := lo.Map(partitions, func(partition *Partition, i int) int64 {
partitionIDs = lo.Map(partitions, func(partition *Partition, i int) int64 {
return partition.PartitionID return partition.PartitionID
}) })
}
}
return mgr.updateCollectionNextTarget(collectionID, partitionIDs...) return mgr.updateCollectionNextTarget(collectionID, partitionIDs...)
} }
@ -146,14 +134,27 @@ func (mgr *TargetManager) updateCollectionNextTarget(collectionID int64, partiti
return nil return nil
} }
func (mgr *TargetManager) PullNextTarget(broker Broker, collectionID int64, partitionIDs ...int64) (*CollectionTarget, error) { func (mgr *TargetManager) PullNextTarget(broker Broker, collectionID int64, chosenPartitionIDs ...int64) (*CollectionTarget, error) {
log.Info("start to pull next targets for partition", log.Info("start to pull next targets for partition",
zap.Int64("collectionID", collectionID), zap.Int64("collectionID", collectionID),
zap.Int64s("partitionIDs", partitionIDs)) zap.Int64s("chosenPartitionIDs", chosenPartitionIDs))
channelInfos := make(map[string][]*datapb.VchannelInfo) channelInfos := make(map[string][]*datapb.VchannelInfo)
segments := make(map[int64]*datapb.SegmentInfo, 0) segments := make(map[int64]*datapb.SegmentInfo, 0)
for _, partitionID := range partitionIDs { dmChannels := make(map[string]*DmChannel)
if len(chosenPartitionIDs) == 0 {
return NewCollectionTarget(segments, dmChannels), nil
}
fullPartitions, err := broker.GetPartitions(context.Background(), collectionID)
if err != nil {
return nil, err
}
// we should pull `channel targets` from all partitions because QueryNodes need to load
// the complete set of growing segments, and we should pull `segment targets` only from the chosen partitions.
for _, partitionID := range fullPartitions {
log.Debug("get recovery info...", log.Debug("get recovery info...",
zap.Int64("collectionID", collectionID), zap.Int64("collectionID", collectionID),
zap.Int64("partitionID", partitionID)) zap.Int64("partitionID", partitionID))
@ -161,7 +162,12 @@ func (mgr *TargetManager) PullNextTarget(broker Broker, collectionID int64, part
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, info := range vChannelInfos {
channelInfos[info.GetChannelName()] = append(channelInfos[info.GetChannelName()], info)
}
if !lo.Contains(chosenPartitionIDs, partitionID) {
continue
}
for _, binlog := range binlogs { for _, binlog := range binlogs {
segments[binlog.GetSegmentID()] = &datapb.SegmentInfo{ segments[binlog.GetSegmentID()] = &datapb.SegmentInfo{
ID: binlog.GetSegmentID(), ID: binlog.GetSegmentID(),
@ -174,18 +180,12 @@ func (mgr *TargetManager) PullNextTarget(broker Broker, collectionID int64, part
Deltalogs: binlog.GetDeltalogs(), Deltalogs: binlog.GetDeltalogs(),
} }
} }
for _, info := range vChannelInfos {
channelInfos[info.GetChannelName()] = append(channelInfos[info.GetChannelName()], info)
}
} }
dmChannels := make(map[string]*DmChannel)
for _, infos := range channelInfos { for _, infos := range channelInfos {
merged := mgr.mergeDmChannelInfo(infos) merged := mgr.mergeDmChannelInfo(infos)
dmChannels[merged.GetChannelName()] = merged dmChannels[merged.GetChannelName()] = merged
} }
return NewCollectionTarget(segments, dmChannels), nil return NewCollectionTarget(segments, dmChannels), nil
} }
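The filter above is the heart of this change: channel info is collected from every partition, while segment info is kept only for the chosen partitions, because growing data arrives per channel rather than per partition. A stripped-down sketch of that selection with plain maps; recoveryInfo and pullTarget are illustrative names, not the broker API.

package main

import "fmt"

// recoveryInfo is an illustrative per-partition view of GetRecoveryInfo results.
type recoveryInfo struct {
	PartitionID int64
	Channels    []string
	SegmentIDs  []int64
}

// pullTarget keeps channels from every partition but segments only from the
// chosen partitions, mirroring the loop in PullNextTarget above.
func pullTarget(all []recoveryInfo, chosen map[int64]bool) (channels map[string]bool, segments map[int64]bool) {
	channels = make(map[string]bool)
	segments = make(map[int64]bool)
	for _, info := range all {
		for _, ch := range info.Channels {
			channels[ch] = true // growing segments live on the channel, always keep it
		}
		if !chosen[info.PartitionID] {
			continue // sealed segments of unchosen partitions are not loaded
		}
		for _, seg := range info.SegmentIDs {
			segments[seg] = true
		}
	}
	return channels, segments
}

func main() {
	all := []recoveryInfo{
		{PartitionID: 10, Channels: []string{"dmc0"}, SegmentIDs: []int64{1, 2}},
		{PartitionID: 11, Channels: []string{"dmc1"}, SegmentIDs: []int64{3}},
	}
	channels, segments := pullTarget(all, map[int64]bool{10: true})
	fmt.Println(len(channels), len(segments)) // 2 1
}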
View File
@ -126,7 +126,7 @@ func (suite *TargetManagerSuite) SetupTest() {
} }
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collection, partition).Return(dmChannels, allSegments, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collection, partition).Return(dmChannels, allSegments, nil)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil)
suite.mgr.UpdateCollectionNextTargetWithPartitions(collection, suite.partitions[collection]...) suite.mgr.UpdateCollectionNextTargetWithPartitions(collection, suite.partitions[collection]...)
} }
} }
@ -192,6 +192,7 @@ func (suite *TargetManagerSuite) TestUpdateNextTarget() {
}, },
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, collectionID).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collectionID, int64(1)).Return(nextTargetChannels, nextTargetSegments, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collectionID, int64(1)).Return(nextTargetChannels, nextTargetSegments, nil)
suite.mgr.UpdateCollectionNextTargetWithPartitions(collectionID, int64(1)) suite.mgr.UpdateCollectionNextTargetWithPartitions(collectionID, int64(1))
suite.assertSegments([]int64{11, 12}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, NextTarget)) suite.assertSegments([]int64{11, 12}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, NextTarget))
View File
@ -1,4 +1,4 @@
// Code generated by mockery v2.16.0. DO NOT EDIT. // Code generated by mockery v2.14.0. DO NOT EDIT.
package mocks package mocks
@ -358,6 +358,53 @@ func (_c *MockQueryNodeServer_GetTimeTickChannel_Call) Return(_a0 *milvuspb.Stri
return _c return _c
} }
// LoadPartitions provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) LoadPartitions(_a0 context.Context, _a1 *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status
if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadPartitionsRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.LoadPartitionsRequest) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryNodeServer_LoadPartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadPartitions'
type MockQueryNodeServer_LoadPartitions_Call struct {
*mock.Call
}
// LoadPartitions is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *querypb.LoadPartitionsRequest
func (_e *MockQueryNodeServer_Expecter) LoadPartitions(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_LoadPartitions_Call {
return &MockQueryNodeServer_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", _a0, _a1)}
}
func (_c *MockQueryNodeServer_LoadPartitions_Call) Run(run func(_a0 context.Context, _a1 *querypb.LoadPartitionsRequest)) *MockQueryNodeServer_LoadPartitions_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.LoadPartitionsRequest))
})
return _c
}
func (_c *MockQueryNodeServer_LoadPartitions_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryNodeServer_LoadPartitions_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// LoadSegments provides a mock function with given fields: _a0, _a1 // LoadSegments provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) LoadSegments(_a0 context.Context, _a1 *querypb.LoadSegmentsRequest) (*commonpb.Status, error) { func (_m *MockQueryNodeServer) LoadSegments(_a0 context.Context, _a1 *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
View File
@ -37,7 +37,6 @@ type CollectionObserver struct {
meta *meta.Meta meta *meta.Meta
targetMgr *meta.TargetManager targetMgr *meta.TargetManager
targetObserver *TargetObserver targetObserver *TargetObserver
collectionLoadedCount map[int64]int
partitionLoadedCount map[int64]int partitionLoadedCount map[int64]int
stopOnce sync.Once stopOnce sync.Once
@ -55,7 +54,6 @@ func NewCollectionObserver(
meta: meta, meta: meta,
targetMgr: targetMgr, targetMgr: targetMgr,
targetObserver: targetObserver, targetObserver: targetObserver,
collectionLoadedCount: make(map[int64]int),
partitionLoadedCount: make(map[int64]int), partitionLoadedCount: make(map[int64]int),
} }
} }
@ -115,36 +113,24 @@ func (ob *CollectionObserver) observeTimeout() {
log.Info("observes partitions timeout", zap.Int("partitionNum", len(partitions))) log.Info("observes partitions timeout", zap.Int("partitionNum", len(partitions)))
} }
for collection, partitions := range partitions { for collection, partitions := range partitions {
log := log.With(
zap.Int64("collectionID", collection),
)
for _, partition := range partitions { for _, partition := range partitions {
if partition.GetStatus() != querypb.LoadStatus_Loading || if partition.GetStatus() != querypb.LoadStatus_Loading ||
time.Now().Before(partition.UpdatedAt.Add(Params.QueryCoordCfg.LoadTimeoutSeconds.GetAsDuration(time.Second))) { time.Now().Before(partition.UpdatedAt.Add(Params.QueryCoordCfg.LoadTimeoutSeconds.GetAsDuration(time.Second))) {
continue continue
} }
log.Info("load partition timeout, cancel all partitions", log.Info("load partition timeout, cancel it",
zap.Int64("collectionID", collection),
zap.Int64("partitionID", partition.GetPartitionID()), zap.Int64("partitionID", partition.GetPartitionID()),
zap.Duration("loadTime", time.Since(partition.CreatedAt))) zap.Duration("loadTime", time.Since(partition.CreatedAt)))
// TODO(yah01): Now, releasing part of partitions is not allowed ob.meta.CollectionManager.RemovePartition(partition.GetPartitionID())
ob.meta.CollectionManager.RemoveCollection(partition.GetCollectionID()) ob.targetMgr.RemovePartition(partition.GetCollectionID(), partition.GetPartitionID())
ob.meta.ReplicaManager.RemoveCollection(partition.GetCollectionID())
ob.targetMgr.RemoveCollection(partition.GetCollectionID())
break break
} }
} }
} }
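With the change above, a load timeout now cancels only the partition that exceeded LoadTimeoutSeconds instead of releasing the whole collection. The check itself is a clock comparison; a minimal sketch with illustrative names:

package main

import (
	"fmt"
	"time"
)

// loadTimedOut reports whether a still-loading partition has been stuck for
// longer than the configured load timeout (compare observeTimeout above).
func loadTimedOut(stillLoading bool, updatedAt time.Time, timeout time.Duration) bool {
	return stillLoading && !time.Now().Before(updatedAt.Add(timeout))
}

func main() {
	updatedAt := time.Now().Add(-5 * time.Second)
	fmt.Println(loadTimedOut(true, updatedAt, 3*time.Second))  // true: remove just this partition
	fmt.Println(loadTimedOut(true, updatedAt, 10*time.Second)) // false: keep waiting
}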
func (ob *CollectionObserver) observeLoadStatus() { func (ob *CollectionObserver) observeLoadStatus() {
collections := ob.meta.CollectionManager.GetAllCollections()
for _, collection := range collections {
if collection.LoadPercentage == 100 {
continue
}
ob.observeCollectionLoadStatus(collection)
}
partitions := ob.meta.CollectionManager.GetAllPartitions() partitions := ob.meta.CollectionManager.GetAllPartitions()
if len(partitions) > 0 { if len(partitions) > 0 {
log.Info("observe partitions status", zap.Int("partitionNum", len(partitions))) log.Info("observe partitions status", zap.Int("partitionNum", len(partitions)))
@ -153,61 +139,30 @@ func (ob *CollectionObserver) observeLoadStatus() {
if partition.LoadPercentage == 100 { if partition.LoadPercentage == 100 {
continue continue
} }
ob.observePartitionLoadStatus(partition) replicaNum := ob.meta.GetReplicaNumber(partition.GetCollectionID())
ob.observePartitionLoadStatus(partition, replicaNum)
}
collections := ob.meta.CollectionManager.GetAllCollections()
for _, collection := range collections {
if collection.LoadPercentage == 100 {
continue
}
ob.observeCollectionLoadStatus(collection)
} }
} }
func (ob *CollectionObserver) observeCollectionLoadStatus(collection *meta.Collection) { func (ob *CollectionObserver) observeCollectionLoadStatus(collection *meta.Collection) {
log := log.With(zap.Int64("collectionID", collection.GetCollectionID())) log := log.With(zap.Int64("collectionID", collection.GetCollectionID()))
segmentTargets := ob.targetMgr.GetHistoricalSegmentsByCollection(collection.GetCollectionID(), meta.NextTarget)
channelTargets := ob.targetMgr.GetDmChannelsByCollection(collection.GetCollectionID(), meta.NextTarget)
targetNum := len(segmentTargets) + len(channelTargets)
log.Info("collection targets",
zap.Int("segmentTargetNum", len(segmentTargets)),
zap.Int("channelTargetNum", len(channelTargets)),
zap.Int("totalTargetNum", targetNum),
zap.Int32("replicaNum", collection.GetReplicaNumber()),
)
updated := collection.Clone() updated := collection.Clone()
loadedCount := 0 percentage := ob.meta.CollectionManager.GetCurrentLoadPercentage(collection.GetCollectionID())
if targetNum == 0 { if percentage <= updated.LoadPercentage {
log.Info("No segment/channel in target need to be loaded!")
updated.LoadPercentage = 100
} else {
for _, channel := range channelTargets {
group := utils.GroupNodesByReplica(ob.meta.ReplicaManager,
collection.GetCollectionID(),
ob.dist.LeaderViewManager.GetChannelDist(channel.GetChannelName()))
loadedCount += len(group)
}
subChannelCount := loadedCount
for _, segment := range segmentTargets {
group := utils.GroupNodesByReplica(ob.meta.ReplicaManager,
collection.GetCollectionID(),
ob.dist.LeaderViewManager.GetSealedSegmentDist(segment.GetID()))
loadedCount += len(group)
}
if loadedCount > 0 {
log.Info("collection load progress",
zap.Int("subChannelCount", subChannelCount),
zap.Int("loadSegmentCount", loadedCount-subChannelCount),
)
}
updated.LoadPercentage = int32(loadedCount * 100 / (targetNum * int(collection.GetReplicaNumber())))
}
if loadedCount <= ob.collectionLoadedCount[collection.GetCollectionID()] &&
updated.LoadPercentage != 100 {
ob.collectionLoadedCount[collection.GetCollectionID()] = loadedCount
return return
} }
ob.collectionLoadedCount[collection.GetCollectionID()] = loadedCount updated.LoadPercentage = percentage
if updated.LoadPercentage == 100 && ob.targetObserver.Check(updated.GetCollectionID()) { if updated.LoadPercentage == 100 && ob.targetObserver.Check(updated.GetCollectionID()) {
delete(ob.collectionLoadedCount, collection.GetCollectionID())
updated.Status = querypb.LoadStatus_Loaded updated.Status = querypb.LoadStatus_Loaded
ob.meta.CollectionManager.UpdateCollection(updated) ob.meta.CollectionManager.UpdateCollection(updated)
@ -221,7 +176,7 @@ func (ob *CollectionObserver) observeCollectionLoadStatus(collection *meta.Colle
zap.Int32("collectionStatus", int32(updated.GetStatus()))) zap.Int32("collectionStatus", int32(updated.GetStatus())))
} }
func (ob *CollectionObserver) observePartitionLoadStatus(partition *meta.Partition) { func (ob *CollectionObserver) observePartitionLoadStatus(partition *meta.Partition, replicaNum int32) {
log := log.With( log := log.With(
zap.Int64("collectionID", partition.GetCollectionID()), zap.Int64("collectionID", partition.GetCollectionID()),
zap.Int64("partitionID", partition.GetPartitionID()), zap.Int64("partitionID", partition.GetPartitionID()),
@ -234,7 +189,7 @@ func (ob *CollectionObserver) observePartitionLoadStatus(partition *meta.Partiti
zap.Int("segmentTargetNum", len(segmentTargets)), zap.Int("segmentTargetNum", len(segmentTargets)),
zap.Int("channelTargetNum", len(channelTargets)), zap.Int("channelTargetNum", len(channelTargets)),
zap.Int("totalTargetNum", targetNum), zap.Int("totalTargetNum", targetNum),
zap.Int32("replicaNum", partition.GetReplicaNumber()), zap.Int32("replicaNum", replicaNum),
) )
loadedCount := 0 loadedCount := 0
@ -261,7 +216,7 @@ func (ob *CollectionObserver) observePartitionLoadStatus(partition *meta.Partiti
zap.Int("subChannelCount", subChannelCount), zap.Int("subChannelCount", subChannelCount),
zap.Int("loadSegmentCount", loadedCount-subChannelCount)) zap.Int("loadSegmentCount", loadedCount-subChannelCount))
} }
updated.LoadPercentage = int32(loadedCount * 100 / (targetNum * int(partition.GetReplicaNumber()))) updated.LoadPercentage = int32(loadedCount * 100 / (targetNum * int(replicaNum)))
} }
if loadedCount <= ob.partitionLoadedCount[partition.GetPartitionID()] && if loadedCount <= ob.partitionLoadedCount[partition.GetPartitionID()] &&
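The partition progress written near the end of this hunk is plain integer arithmetic: every channel and segment target has to be served once per replica. A worked sketch of that formula; the zero-total guard exists only to keep the sketch safe, the observer handles empty targets on its own path.

package main

import "fmt"

// loadPercentage mirrors loadedCount * 100 / (targetNum * replicaNum):
// targetNum counts channels plus segments, and each must be served per replica.
func loadPercentage(loadedCount, targetNum, replicaNum int) int32 {
	total := targetNum * replicaNum
	if total == 0 {
		return 0 // guard for the sketch only
	}
	return int32(loadedCount * 100 / total)
}

func main() {
	// 2 channels + 4 segments = 6 targets, 2 replicas, 9 of 12 placements done.
	fmt.Println(loadPercentage(9, 6, 2)) // 75
}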
View File
@ -21,6 +21,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/samber/lo"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
clientv3 "go.etcd.io/etcd/client/v3" clientv3 "go.etcd.io/etcd/client/v3"
@ -200,7 +201,7 @@ func (suite *CollectionObserverSuite) SetupTest() {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe() suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe()
} }
suite.targetObserver.Start(context.Background()) suite.targetObserver.Start(context.Background())
suite.ob.Start(context.Background())
suite.loadAll() suite.loadAll()
} }
@ -212,22 +213,12 @@ func (suite *CollectionObserverSuite) TearDownTest() {
func (suite *CollectionObserverSuite) TestObserve() { func (suite *CollectionObserverSuite) TestObserve() {
const ( const (
timeout = 2 * time.Second timeout = 3 * time.Second
) )
// time before load // time before load
time := suite.meta.GetCollection(suite.collections[2]).UpdatedAt time := suite.meta.GetCollection(suite.collections[2]).UpdatedAt
// Not timeout // Not timeout
paramtable.Get().Save(Params.QueryCoordCfg.LoadTimeoutSeconds.Key, "2") paramtable.Get().Save(Params.QueryCoordCfg.LoadTimeoutSeconds.Key, "3")
segments := []*datapb.SegmentBinlogs{}
for _, segment := range suite.segments[100] {
segments = append(segments, &datapb.SegmentBinlogs{
SegmentID: segment.GetID(),
InsertChannel: segment.GetInsertChannel(),
})
}
suite.ob.Start(context.Background())
// Collection 100 loaded before timeout, // Collection 100 loaded before timeout,
// collection 101 timeout // collection 101 timeout
@ -282,9 +273,45 @@ func (suite *CollectionObserverSuite) TestObserve() {
}, timeout*2, timeout/10) }, timeout*2, timeout/10)
} }
func (suite *CollectionObserverSuite) TestObservePartition() {
const (
timeout = 3 * time.Second
)
paramtable.Get().Save(Params.QueryCoordCfg.LoadTimeoutSeconds.Key, "3")
// Partition 10 loaded
suite.dist.LeaderViewManager.Update(1, &meta.LeaderView{
ID: 1,
CollectionID: 100,
Channel: "100-dmc0",
Segments: map[int64]*querypb.SegmentDist{1: {NodeID: 1, Version: 0}},
})
suite.dist.LeaderViewManager.Update(2, &meta.LeaderView{
ID: 2,
CollectionID: 100,
Channel: "100-dmc1",
Segments: map[int64]*querypb.SegmentDist{2: {NodeID: 2, Version: 0}},
})
// Partition 11 timeout
suite.dist.LeaderViewManager.Update(1, &meta.LeaderView{
ID: 1,
CollectionID: 101,
Channel: "",
Segments: map[int64]*querypb.SegmentDist{},
})
suite.Eventually(func() bool {
return suite.isPartitionLoaded(suite.partitions[100][0])
}, timeout*2, timeout/10)
suite.Eventually(func() bool {
return suite.isPartitionTimeout(suite.collections[1], suite.partitions[101][0])
}, timeout*2, timeout/10)
}
func (suite *CollectionObserverSuite) isCollectionLoaded(collection int64) bool { func (suite *CollectionObserverSuite) isCollectionLoaded(collection int64) bool {
exist := suite.meta.Exist(collection) exist := suite.meta.Exist(collection)
percentage := suite.meta.GetLoadPercentage(collection) percentage := suite.meta.GetCurrentLoadPercentage(collection)
status := suite.meta.GetStatus(collection) status := suite.meta.GetStatus(collection)
replicas := suite.meta.ReplicaManager.GetByCollection(collection) replicas := suite.meta.ReplicaManager.GetByCollection(collection)
channels := suite.targetMgr.GetDmChannelsByCollection(collection, meta.CurrentTarget) channels := suite.targetMgr.GetDmChannelsByCollection(collection, meta.CurrentTarget)
@ -298,6 +325,25 @@ func (suite *CollectionObserverSuite) isCollectionLoaded(collection int64) bool
len(segments) == len(suite.segments[collection]) len(segments) == len(suite.segments[collection])
} }
func (suite *CollectionObserverSuite) isPartitionLoaded(partitionID int64) bool {
partition := suite.meta.GetPartition(partitionID)
if partition == nil {
return false
}
collection := partition.GetCollectionID()
percentage := suite.meta.GetPartitionLoadPercentage(partitionID)
status := partition.GetStatus()
channels := suite.targetMgr.GetDmChannelsByCollection(collection, meta.CurrentTarget)
segments := suite.targetMgr.GetHistoricalSegmentsByPartition(collection, partitionID, meta.CurrentTarget)
expectedSegments := lo.Filter(suite.segments[collection], func(seg *datapb.SegmentInfo, _ int) bool {
return seg.PartitionID == partitionID
})
return percentage == 100 &&
status == querypb.LoadStatus_Loaded &&
len(channels) == len(suite.channels[collection]) &&
len(segments) == len(expectedSegments)
}
func (suite *CollectionObserverSuite) isCollectionTimeout(collection int64) bool { func (suite *CollectionObserverSuite) isCollectionTimeout(collection int64) bool {
exist := suite.meta.Exist(collection) exist := suite.meta.Exist(collection)
replicas := suite.meta.ReplicaManager.GetByCollection(collection) replicas := suite.meta.ReplicaManager.GetByCollection(collection)
@ -309,9 +355,14 @@ func (suite *CollectionObserverSuite) isCollectionTimeout(collection int64) bool
len(segments) > 0) len(segments) > 0)
} }
func (suite *CollectionObserverSuite) isPartitionTimeout(collection int64, partitionID int64) bool {
partition := suite.meta.GetPartition(partitionID)
segments := suite.targetMgr.GetHistoricalSegmentsByPartition(collection, partitionID, meta.CurrentTarget)
return partition == nil && len(segments) == 0
}
func (suite *CollectionObserverSuite) isCollectionLoadedContinue(collection int64, beforeTime time.Time) bool { func (suite *CollectionObserverSuite) isCollectionLoadedContinue(collection int64, beforeTime time.Time) bool {
return suite.meta.GetCollection(collection).UpdatedAt.After(beforeTime) return suite.meta.GetCollection(collection).UpdatedAt.After(beforeTime)
} }
func (suite *CollectionObserverSuite) loadAll() { func (suite *CollectionObserverSuite) loadAll() {
@ -332,17 +383,17 @@ func (suite *CollectionObserverSuite) load(collection int64) {
err = suite.meta.ReplicaManager.Put(replicas...) err = suite.meta.ReplicaManager.Put(replicas...)
suite.NoError(err) suite.NoError(err)
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection {
suite.meta.PutCollection(&meta.Collection{ suite.meta.PutCollection(&meta.Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: collection, CollectionID: collection,
ReplicaNumber: suite.replicaNumber[collection], ReplicaNumber: suite.replicaNumber[collection],
Status: querypb.LoadStatus_Loading, Status: querypb.LoadStatus_Loading,
LoadType: suite.loadTypes[collection],
}, },
LoadPercentage: 0, LoadPercentage: 0,
CreatedAt: time.Now(), CreatedAt: time.Now(),
}) })
} else {
for _, partition := range suite.partitions[collection] { for _, partition := range suite.partitions[collection] {
suite.meta.PutPartition(&meta.Partition{ suite.meta.PutPartition(&meta.Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{ PartitionLoadInfo: &querypb.PartitionLoadInfo{
@ -355,9 +406,8 @@ func (suite *CollectionObserverSuite) load(collection int64) {
CreatedAt: time.Now(), CreatedAt: time.Now(),
}) })
} }
}
allSegments := make([]*datapb.SegmentBinlogs, 0) allSegments := make(map[int64][]*datapb.SegmentBinlogs, 0) // partitionID -> segments
dmChannels := make([]*datapb.VchannelInfo, 0) dmChannels := make([]*datapb.VchannelInfo, 0)
for _, channel := range suite.channels[collection] { for _, channel := range suite.channels[collection] {
dmChannels = append(dmChannels, &datapb.VchannelInfo{ dmChannels = append(dmChannels, &datapb.VchannelInfo{
@ -367,16 +417,15 @@ func (suite *CollectionObserverSuite) load(collection int64) {
} }
for _, segment := range suite.segments[collection] { for _, segment := range suite.segments[collection] {
allSegments = append(allSegments, &datapb.SegmentBinlogs{ allSegments[segment.PartitionID] = append(allSegments[segment.PartitionID], &datapb.SegmentBinlogs{
SegmentID: segment.GetID(), SegmentID: segment.GetID(),
InsertChannel: segment.GetInsertChannel(), InsertChannel: segment.GetInsertChannel(),
}) })
} }
partitions := suite.partitions[collection] partitions := suite.partitions[collection]
for _, partition := range partitions { for _, partition := range partitions {
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collection, partition).Return(dmChannels, allSegments, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collection, partition).Return(dmChannels, allSegments[partition], nil)
} }
suite.targetMgr.UpdateCollectionNextTargetWithPartitions(collection, partitions...) suite.targetMgr.UpdateCollectionNextTargetWithPartitions(collection, partitions...)
} }
View File
@ -97,6 +97,7 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegments() {
ChannelName: "test-insert-channel", ChannelName: "test-insert-channel",
}, },
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
channels, segments, nil) channels, segments, nil)
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
@ -152,6 +153,7 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncLoadedSegments() {
ChannelName: "test-insert-channel", ChannelName: "test-insert-channel",
}, },
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
channels, segments, nil) channels, segments, nil)
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
@ -209,6 +211,8 @@ func (suite *LeaderObserverTestSuite) TestIgnoreBalancedSegment() {
ChannelName: "test-insert-channel", ChannelName: "test-insert-channel",
}, },
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
channels, segments, nil) channels, segments, nil)
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
@ -247,6 +251,7 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegmentsWithReplicas() {
ChannelName: "test-insert-channel", ChannelName: "test-insert-channel",
}, },
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
channels, segments, nil) channels, segments, nil)
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
@ -340,6 +345,7 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncRemovedSegments() {
ChannelName: "test-insert-channel", ChannelName: "test-insert-channel",
}, },
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
channels, segments, nil) channels, segments, nil)
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
View File
@ -37,6 +37,7 @@ type checkRequest struct {
type targetUpdateRequest struct { type targetUpdateRequest struct {
CollectionID int64 CollectionID int64
PartitionIDs []int64
Notifier chan error Notifier chan error
ReadyNotifier chan struct{} ReadyNotifier chan struct{}
} }
@ -108,7 +109,7 @@ func (ob *TargetObserver) schedule(ctx context.Context) {
req.Notifier <- ob.targetMgr.IsCurrentTargetExist(req.CollectionID) req.Notifier <- ob.targetMgr.IsCurrentTargetExist(req.CollectionID)
case req := <-ob.updateChan: case req := <-ob.updateChan:
err := ob.updateNextTarget(req.CollectionID) err := ob.updateNextTarget(req.CollectionID, req.PartitionIDs...)
if err != nil { if err != nil {
close(req.ReadyNotifier) close(req.ReadyNotifier)
} else { } else {
@ -148,13 +149,14 @@ func (ob *TargetObserver) check(collectionID int64) {
// UpdateNextTarget updates the next target, // UpdateNextTarget updates the next target,
// returns a channel which will be closed when the next target is ready, // returns a channel which will be closed when the next target is ready,
// or returns error if failed to pull target // or returns error if failed to pull target
func (ob *TargetObserver) UpdateNextTarget(collectionID int64) (chan struct{}, error) { func (ob *TargetObserver) UpdateNextTarget(collectionID int64, partitionIDs ...int64) (chan struct{}, error) {
notifier := make(chan error) notifier := make(chan error)
readyCh := make(chan struct{}) readyCh := make(chan struct{})
defer close(notifier) defer close(notifier)
ob.updateChan <- targetUpdateRequest{ ob.updateChan <- targetUpdateRequest{
CollectionID: collectionID, CollectionID: collectionID,
PartitionIDs: partitionIDs,
Notifier: notifier, Notifier: notifier,
ReadyNotifier: readyCh, ReadyNotifier: readyCh,
} }
@ -208,11 +210,16 @@ func (ob *TargetObserver) isNextTargetExpired(collectionID int64) bool {
return time.Since(ob.nextTargetLastUpdate[collectionID]) > params.Params.QueryCoordCfg.NextTargetSurviveTime.GetAsDuration(time.Second) return time.Since(ob.nextTargetLastUpdate[collectionID]) > params.Params.QueryCoordCfg.NextTargetSurviveTime.GetAsDuration(time.Second)
} }
func (ob *TargetObserver) updateNextTarget(collectionID int64) error { func (ob *TargetObserver) updateNextTarget(collectionID int64, partitionIDs ...int64) error {
log := log.With(zap.Int64("collectionID", collectionID)) log := log.With(zap.Int64("collectionID", collectionID), zap.Int64s("partIDs", partitionIDs))
log.Info("observer trigger update next target") log.Info("observer trigger update next target")
err := ob.targetMgr.UpdateCollectionNextTarget(collectionID) var err error
if len(partitionIDs) == 0 {
err = ob.targetMgr.UpdateCollectionNextTarget(collectionID)
} else {
err = ob.targetMgr.UpdateCollectionNextTargetWithPartitions(collectionID, partitionIDs...)
}
if err != nil { if err != nil {
log.Error("failed to update next target for collection", log.Error("failed to update next target for collection",
zap.Error(err)) zap.Error(err))
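UpdateNextTarget now threads the requested partition IDs through the update channel, and updateNextTarget falls back to a whole-collection refresh when none are given. A self-contained sketch of the request/notify handshake a caller relies on; updateRequest and the channel wiring below are illustrative, not the actual TargetObserver internals.

package main

import "fmt"

// updateRequest is an illustrative, trimmed-down version of the
// targetUpdateRequest above: the caller learns about pull errors through
// Notifier and waits on ReadyNotifier until the new target is in use.
type updateRequest struct {
	CollectionID  int64
	PartitionIDs  []int64
	Notifier      chan error
	ReadyNotifier chan struct{}
}

func main() {
	updateChan := make(chan updateRequest)

	// Observer side: pull the next target, report the result, and close
	// ReadyNotifier once the target has been promoted to current.
	go func() {
		req := <-updateChan
		fmt.Printf("updating next target: collection=%d partitions=%v\n",
			req.CollectionID, req.PartitionIDs)
		req.Notifier <- nil      // pull succeeded
		close(req.ReadyNotifier) // target ready
	}()

	// Caller side, mirroring UpdateNextTarget(collectionID, partitionIDs...).
	req := updateRequest{
		CollectionID:  100,
		PartitionIDs:  []int64{10, 11},
		Notifier:      make(chan error),
		ReadyNotifier: make(chan struct{}),
	}
	updateChan <- req
	if err := <-req.Notifier; err != nil {
		fmt.Println("failed to pull next target:", err)
		return
	}
	<-req.ReadyNotifier
	fmt.Println("next target ready for collection 100")
}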
View File
@ -87,6 +87,8 @@ func (suite *TargetObserverSuite) SetupTest() {
err = suite.meta.CollectionManager.PutCollection(utils.CreateTestCollection(suite.collectionID, 1)) err = suite.meta.CollectionManager.PutCollection(utils.CreateTestCollection(suite.collectionID, 1))
suite.NoError(err) suite.NoError(err)
err = suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(suite.collectionID, suite.partitionID))
suite.NoError(err)
replicas, err := suite.meta.ReplicaManager.Spawn(suite.collectionID, 1, meta.DefaultResourceGroupName) replicas, err := suite.meta.ReplicaManager.Spawn(suite.collectionID, 1, meta.DefaultResourceGroupName)
suite.NoError(err) suite.NoError(err)
replicas[0].AddNode(2) replicas[0].AddNode(2)
@ -115,8 +117,8 @@ func (suite *TargetObserverSuite) SetupTest() {
}, },
} }
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, mock.Anything, mock.Anything).Return(suite.nextTargetChannels, suite.nextTargetSegments, nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, mock.Anything).Return([]int64{suite.partitionID}, nil) suite.broker.EXPECT().GetPartitions(mock.Anything, mock.Anything).Return([]int64{suite.partitionID}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, mock.Anything, mock.Anything).Return(suite.nextTargetChannels, suite.nextTargetSegments, nil)
} }
func (suite *TargetObserverSuite) TestTriggerUpdateTarget() { func (suite *TargetObserverSuite) TestTriggerUpdateTarget() {
@ -158,12 +160,10 @@ func (suite *TargetObserverSuite) TestTriggerUpdateTarget() {
suite.targetMgr.UpdateCollectionCurrentTarget(suite.collectionID) suite.targetMgr.UpdateCollectionCurrentTarget(suite.collectionID)
// Pull next again // Pull next again
suite.broker.EXPECT().GetPartitions(mock.Anything, mock.Anything).Return([]int64{suite.partitionID}, nil)
suite.broker.EXPECT(). suite.broker.EXPECT().
GetRecoveryInfo(mock.Anything, mock.Anything, mock.Anything). GetRecoveryInfo(mock.Anything, mock.Anything, mock.Anything).
Return(suite.nextTargetChannels, suite.nextTargetSegments, nil) Return(suite.nextTargetChannels, suite.nextTargetSegments, nil)
suite.broker.EXPECT().
GetPartitions(mock.Anything, mock.Anything).
Return([]int64{suite.partitionID}, nil)
suite.Eventually(func() bool { suite.Eventually(func() bool {
return len(suite.targetMgr.GetHistoricalSegmentsByCollection(suite.collectionID, meta.NextTarget)) == 3 && return len(suite.targetMgr.GetHistoricalSegmentsByCollection(suite.collectionID, meta.NextTarget)) == 3 &&
len(suite.targetMgr.GetDmChannelsByCollection(suite.collectionID, meta.NextTarget)) == 2 len(suite.targetMgr.GetDmChannelsByCollection(suite.collectionID, meta.NextTarget)) == 2
View File
@ -286,8 +286,13 @@ func (s *Server) initMeta() error {
s.store = meta.NewMetaStore(s.kv) s.store = meta.NewMetaStore(s.kv)
s.meta = meta.NewMeta(s.idAllocator, s.store, s.nodeMgr) s.meta = meta.NewMeta(s.idAllocator, s.store, s.nodeMgr)
s.broker = meta.NewCoordinatorBroker(
s.dataCoord,
s.rootCoord,
)
log.Info("recover meta...") log.Info("recover meta...")
err := s.meta.CollectionManager.Recover() err := s.meta.CollectionManager.Recover(s.broker)
if err != nil { if err != nil {
log.Error("failed to recover collections") log.Error("failed to recover collections")
return err return err
@ -295,6 +300,7 @@ func (s *Server) initMeta() error {
collections := s.meta.GetAll() collections := s.meta.GetAll()
log.Info("recovering collections...", zap.Int64s("collections", collections)) log.Info("recovering collections...", zap.Int64s("collections", collections))
metrics.QueryCoordNumCollections.WithLabelValues().Set(float64(len(collections))) metrics.QueryCoordNumCollections.WithLabelValues().Set(float64(len(collections)))
metrics.QueryCoordNumPartitions.WithLabelValues().Set(float64(len(s.meta.GetAllPartitions())))
err = s.meta.ReplicaManager.Recover(collections) err = s.meta.ReplicaManager.Recover(collections)
if err != nil { if err != nil {
@ -313,10 +319,6 @@ func (s *Server) initMeta() error {
ChannelDistManager: meta.NewChannelDistManager(), ChannelDistManager: meta.NewChannelDistManager(),
LeaderViewManager: meta.NewLeaderViewManager(), LeaderViewManager: meta.NewLeaderViewManager(),
} }
s.broker = meta.NewCoordinatorBroker(
s.dataCoord,
s.rootCoord,
)
s.targetMgr = meta.NewTargetManager(s.broker, s.meta) s.targetMgr = meta.NewTargetManager(s.broker, s.meta)
record.Record("Server initMeta") record.Record("Server initMeta")
View File
@ -23,7 +23,6 @@ import (
"time" "time"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@ -42,6 +41,7 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/internal/querycoordv2/task"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/internal/util/commonpbutil" "github.com/milvus-io/milvus/internal/util/commonpbutil"
"github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/merr" "github.com/milvus-io/milvus/internal/util/merr"
@ -116,6 +116,7 @@ func (suite *ServerSuite) SetupTest() {
ok := suite.waitNodeUp(suite.nodes[i], 5*time.Second) ok := suite.waitNodeUp(suite.nodes[i], 5*time.Second)
suite.Require().True(ok) suite.Require().True(ok)
suite.server.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, suite.nodes[i].ID) suite.server.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, suite.nodes[i].ID)
suite.expectLoadAndReleasePartitions(suite.nodes[i])
} }
suite.loadAll() suite.loadAll()
@ -158,14 +159,15 @@ func (suite *ServerSuite) TestRecoverFailed() {
suite.NoError(err) suite.NoError(err)
broker := meta.NewMockBroker(suite.T()) broker := meta.NewMockBroker(suite.T())
broker.EXPECT().GetPartitions(context.TODO(), int64(1000)).Return(nil, errors.New("CollectionNotExist")) for _, collection := range suite.collections {
broker.EXPECT().GetRecoveryInfo(context.TODO(), int64(1001), mock.Anything).Return(nil, nil, errors.New("CollectionNotExist")) broker.EXPECT().GetPartitions(mock.Anything, collection).Return([]int64{1}, nil)
broker.EXPECT().GetRecoveryInfo(context.TODO(), collection, mock.Anything).Return(nil, nil, errors.New("CollectionNotExist"))
}
suite.server.targetMgr = meta.NewTargetManager(broker, suite.server.meta) suite.server.targetMgr = meta.NewTargetManager(broker, suite.server.meta)
err = suite.server.Start() err = suite.server.Start()
suite.NoError(err) suite.NoError(err)
for _, collection := range suite.collections { for _, collection := range suite.collections {
suite.False(suite.server.meta.Exist(collection))
suite.Nil(suite.server.targetMgr.GetDmChannelsByCollection(collection, meta.NextTarget)) suite.Nil(suite.server.targetMgr.GetDmChannelsByCollection(collection, meta.NextTarget))
} }
} }
@ -259,7 +261,6 @@ func (suite *ServerSuite) TestEnableActiveStandby() {
Schema: &schemapb.CollectionSchema{}, Schema: &schemapb.CollectionSchema{},
}, nil).Maybe() }, nil).Maybe()
for _, collection := range suite.collections { for _, collection := range suite.collections {
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection {
req := &milvuspb.ShowPartitionsRequest{ req := &milvuspb.ShowPartitionsRequest{
Base: commonpbutil.NewMsgBase( Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions), commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions),
@ -270,9 +271,7 @@ func (suite *ServerSuite) TestEnableActiveStandby() {
Status: merr.Status(nil), Status: merr.Status(nil),
PartitionIDs: suite.partitions[collection], PartitionIDs: suite.partitions[collection],
}, nil).Maybe() }, nil).Maybe()
}
suite.expectGetRecoverInfoByMockDataCoord(collection, mockDataCoord) suite.expectGetRecoverInfoByMockDataCoord(collection, mockDataCoord)
} }
err = suite.server.SetRootCoord(mockRootCoord) err = suite.server.SetRootCoord(mockRootCoord)
suite.NoError(err) suite.NoError(err)
@ -385,6 +384,11 @@ func (suite *ServerSuite) expectGetRecoverInfo(collection int64) {
} }
} }
func (suite *ServerSuite) expectLoadAndReleasePartitions(querynode *mocks.MockQueryNode) {
querynode.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil).Maybe()
querynode.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil).Maybe()
}
func (suite *ServerSuite) expectGetRecoverInfoByMockDataCoord(collection int64, dataCoord *coordMocks.DataCoord) { func (suite *ServerSuite) expectGetRecoverInfoByMockDataCoord(collection int64, dataCoord *coordMocks.DataCoord) {
var ( var (
vChannels []*datapb.VchannelInfo vChannels []*datapb.VchannelInfo
@ -432,7 +436,7 @@ func (suite *ServerSuite) updateCollectionStatus(collectionID int64, status quer
} }
collection.CollectionLoadInfo.Status = status collection.CollectionLoadInfo.Status = status
suite.server.meta.UpdateCollection(collection) suite.server.meta.UpdateCollection(collection)
} else {
partitions := suite.server.meta.GetPartitionsByCollection(collectionID) partitions := suite.server.meta.GetPartitionsByCollection(collectionID)
for _, partition := range partitions { for _, partition := range partitions {
partition := partition.Clone() partition := partition.Clone()
@ -488,9 +492,7 @@ func (suite *ServerSuite) hackServer() {
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(&schemapb.CollectionSchema{}, nil).Maybe() suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(&schemapb.CollectionSchema{}, nil).Maybe()
for _, collection := range suite.collections { for _, collection := range suite.collections {
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe() suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe()
}
suite.expectGetRecoverInfo(collection) suite.expectGetRecoverInfo(collection)
} }
log.Debug("server hacked") log.Debug("server hacked")
View File
@ -76,9 +76,6 @@ func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectio
for _, collection := range s.meta.GetAllCollections() { for _, collection := range s.meta.GetAllCollections() {
collectionSet.Insert(collection.GetCollectionID()) collectionSet.Insert(collection.GetCollectionID())
} }
for _, partition := range s.meta.GetAllPartitions() {
collectionSet.Insert(partition.GetCollectionID())
}
isGetAll = true isGetAll = true
} }
collections := collectionSet.Collect() collections := collectionSet.Collect()
@ -92,7 +89,7 @@ func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectio
for _, collectionID := range collections { for _, collectionID := range collections {
log := log.With(zap.Int64("collectionID", collectionID)) log := log.With(zap.Int64("collectionID", collectionID))
percentage := s.meta.CollectionManager.GetLoadPercentage(collectionID) percentage := s.meta.CollectionManager.GetCollectionLoadPercentage(collectionID)
if percentage < 0 { if percentage < 0 {
if isGetAll { if isGetAll {
// The collection is released during this, // The collection is released during this,
@ -139,68 +136,34 @@ func (s *Server) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions
} }
defer meta.GlobalFailedLoadCache.TryExpire() defer meta.GlobalFailedLoadCache.TryExpire()
// TODO(yah01): now, for load collection, the percentage of partition is equal to the percentage of collection,
// we can calculates the real percentage of partitions
partitions := req.GetPartitionIDs() partitions := req.GetPartitionIDs()
percentages := make([]int64, 0) percentages := make([]int64, 0)
isReleased := false
switch s.meta.GetLoadType(req.GetCollectionID()) {
case querypb.LoadType_LoadCollection:
percentage := s.meta.GetLoadPercentage(req.GetCollectionID())
if percentage < 0 {
isReleased = true
break
}
if len(partitions) == 0 {
var err error
partitions, err = s.broker.GetPartitions(ctx, req.GetCollectionID())
if err != nil {
msg := "failed to show partitions"
log.Warn(msg, zap.Error(err))
return &querypb.ShowPartitionsResponse{
Status: utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg, err),
}, nil
}
}
for range partitions {
percentages = append(percentages, int64(percentage))
}
case querypb.LoadType_LoadPartition:
if len(partitions) == 0 { if len(partitions) == 0 {
partitions = lo.Map(s.meta.GetPartitionsByCollection(req.GetCollectionID()), func(partition *meta.Partition, _ int) int64 { partitions = lo.Map(s.meta.GetPartitionsByCollection(req.GetCollectionID()), func(partition *meta.Partition, _ int) int64 {
return partition.GetPartitionID() return partition.GetPartitionID()
}) })
} }
for _, partitionID := range partitions { for _, partitionID := range partitions {
partition := s.meta.GetPartition(partitionID) percentage := s.meta.GetPartitionLoadPercentage(partitionID)
if partition == nil { if percentage < 0 {
isReleased = true
break
}
percentages = append(percentages, int64(partition.LoadPercentage))
}
default:
isReleased = true
}
if isReleased {
err := meta.GlobalFailedLoadCache.Get(req.GetCollectionID()) err := meta.GlobalFailedLoadCache.Get(req.GetCollectionID())
if err != nil { if err != nil {
status := merr.Status(err) status := merr.Status(err)
status.ErrorCode = commonpb.ErrorCode_InsufficientMemoryToLoad status.ErrorCode = commonpb.ErrorCode_InsufficientMemoryToLoad
log.Warn("show partition failed", zap.Error(err))
return &querypb.ShowPartitionsResponse{ return &querypb.ShowPartitionsResponse{
Status: status, Status: status,
}, nil }, nil
} }
msg := fmt.Sprintf("collection %v has not been loaded into QueryNode", req.GetCollectionID()) msg := fmt.Sprintf("partition %d has not been loaded to memory or load failed", partitionID)
log.Warn(msg) log.Warn(msg)
return &querypb.ShowPartitionsResponse{ return &querypb.ShowPartitionsResponse{
Status: utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg), Status: utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg),
}, nil }, nil
} }
percentages = append(percentages, int64(percentage))
}
return &querypb.ShowPartitionsResponse{ return &querypb.ShowPartitionsResponse{
Status: merr.Status(nil), Status: merr.Status(nil),
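Because the removed load-type switch is interleaved with its replacement above, the new per-partition progress loop in ShowPartitions is easier to follow when read straight through. The sketch below is condensed from the new lines of this hunk; it assumes the imports already present in this file, with partitions and percentages declared as shown above, and omits the surrounding function.

	if len(partitions) == 0 {
		partitions = lo.Map(s.meta.GetPartitionsByCollection(req.GetCollectionID()), func(partition *meta.Partition, _ int) int64 {
			return partition.GetPartitionID()
		})
	}
	for _, partitionID := range partitions {
		// each partition now tracks its own load percentage
		percentage := s.meta.GetPartitionLoadPercentage(partitionID)
		if percentage < 0 {
			// distinguish a cached load failure from a partition that was simply never loaded
			if err := meta.GlobalFailedLoadCache.Get(req.GetCollectionID()); err != nil {
				status := merr.Status(err)
				status.ErrorCode = commonpb.ErrorCode_InsufficientMemoryToLoad
				log.Warn("show partition failed", zap.Error(err))
				return &querypb.ShowPartitionsResponse{Status: status}, nil
			}
			msg := fmt.Sprintf("partition %d has not been loaded to memory or load failed", partitionID)
			log.Warn(msg)
			return &querypb.ShowPartitionsResponse{Status: utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg)}, nil
		}
		percentages = append(percentages, int64(percentage))
	}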
@ -246,7 +209,9 @@ func (s *Server) LoadCollection(ctx context.Context, req *querypb.LoadCollection
req, req,
s.dist, s.dist,
s.meta, s.meta,
s.cluster,
s.targetMgr, s.targetMgr,
s.targetObserver,
s.broker, s.broker,
s.nodeMgr, s.nodeMgr,
) )
@ -340,7 +305,9 @@ func (s *Server) LoadPartitions(ctx context.Context, req *querypb.LoadPartitions
req, req,
s.dist, s.dist,
s.meta, s.meta,
s.cluster,
s.targetMgr, s.targetMgr,
s.targetObserver,
s.broker, s.broker,
s.nodeMgr, s.nodeMgr,
) )
@ -401,6 +368,7 @@ func (s *Server) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart
req, req,
s.dist, s.dist,
s.meta, s.meta,
s.cluster,
s.targetMgr, s.targetMgr,
s.targetObserver, s.targetObserver,
) )
@ -528,6 +496,31 @@ func (s *Server) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo
}, nil }, nil
} }
func (s *Server) SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.Int64("partitionID", req.GetPartitionID()),
)
log.Info("received sync new created partition request")
failedMsg := "failed to sync new created partition"
if s.status.Load() != commonpb.StateCode_Healthy {
log.Warn(failedMsg, zap.Error(ErrNotHealthy))
return utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, failedMsg, ErrNotHealthy), nil
}
syncJob := job.NewSyncNewCreatedPartitionJob(ctx, req, s.meta, s.cluster)
s.jobScheduler.Add(syncJob)
err := syncJob.Wait()
if err != nil && !errors.Is(err, job.ErrPartitionNotInTarget) {
log.Warn(failedMsg, zap.Error(err))
return utils.WrapStatus(errCode(err), failedMsg, err), nil
}
return merr.Status(nil), nil
}
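For orientation, a minimal sketch of how a caller might use this new RPC after creating a partition on a loaded collection. The interface, function name, and IDs below are illustrative assumptions, not part of the commit; only the SyncNewCreatedPartition signature and the request fields come from the handler above.

	package example

	import (
		"context"
		"fmt"

		"github.com/milvus-io/milvus-proto/go-api/commonpb"
		"github.com/milvus-io/milvus/internal/proto/querypb"
	)

	// partitionSyncer stands for any client exposing the new RPC (e.g. the QueryCoord client).
	type partitionSyncer interface {
		SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error)
	}

	// notifyQueryCoord asks QueryCoord to start serving a freshly created partition.
	func notifyQueryCoord(ctx context.Context, qc partitionSyncer, collectionID, partitionID int64) error {
		status, err := qc.SyncNewCreatedPartition(ctx, &querypb.SyncNewCreatedPartitionRequest{
			CollectionID: collectionID,
			PartitionID:  partitionID,
		})
		if err != nil {
			return err
		}
		if status.GetErrorCode() != commonpb.ErrorCode_Success {
			return fmt.Errorf("sync new created partition failed: %s", status.GetReason())
		}
		return nil
	}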
// refreshCollection must be called after loading a collection. It looks for new segments that are not loaded yet and // refreshCollection must be called after loading a collection. It looks for new segments that are not loaded yet and
// tries to load them up. It returns when all segments of the given collection are loaded, or when error happens. // tries to load them up. It returns when all segments of the given collection are loaded, or when error happens.
// Note that a collection's loading progress always stays at 100% after a successful load and will not get updated // Note that a collection's loading progress always stays at 100% after a successful load and will not get updated
@ -547,7 +540,7 @@ func (s *Server) refreshCollection(ctx context.Context, collID int64) (*commonpb
} }
// Check that collection is fully loaded. // Check that collection is fully loaded.
if s.meta.CollectionManager.GetLoadPercentage(collID) != 100 { if s.meta.CollectionManager.GetCurrentLoadPercentage(collID) != 100 {
errMsg := "a collection must be fully loaded before refreshing" errMsg := "a collection must be fully loaded before refreshing"
log.Warn(errMsg) log.Warn(errMsg)
return &commonpb.Status{ return &commonpb.Status{
@ -601,7 +594,7 @@ func (s *Server) refreshPartitions(ctx context.Context, collID int64, partIDs []
} }
// Check that all partitions are fully loaded. // Check that all partitions are fully loaded.
if s.meta.CollectionManager.GetLoadPercentage(collID) != 100 { if s.meta.CollectionManager.GetCurrentLoadPercentage(collID) != 100 {
errMsg := "partitions must be fully loaded before refreshing" errMsg := "partitions must be fully loaded before refreshing"
log.Warn(errMsg) log.Warn(errMsg)
return &commonpb.Status{ return &commonpb.Status{
@ -671,7 +664,7 @@ func (s *Server) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques
log.Warn(msg, zap.Int("source-nodes-num", len(req.GetSourceNodeIDs()))) log.Warn(msg, zap.Int("source-nodes-num", len(req.GetSourceNodeIDs())))
return utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg), nil return utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg), nil
} }
if s.meta.CollectionManager.GetLoadPercentage(req.GetCollectionID()) < 100 { if s.meta.CollectionManager.GetCurrentLoadPercentage(req.GetCollectionID()) < 100 {
msg := "can't balance segments of not fully loaded collection" msg := "can't balance segments of not fully loaded collection"
log.Warn(msg) log.Warn(msg)
return utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg), nil return utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg), nil
@ -845,7 +838,7 @@ func (s *Server) GetShardLeaders(ctx context.Context, req *querypb.GetShardLeade
Status: merr.Status(nil), Status: merr.Status(nil),
} }
if s.meta.CollectionManager.GetLoadPercentage(req.GetCollectionID()) < 100 { if s.meta.CollectionManager.GetCurrentLoadPercentage(req.GetCollectionID()) < 100 {
msg := fmt.Sprintf("collection %v is not fully loaded", req.GetCollectionID()) msg := fmt.Sprintf("collection %v is not fully loaded", req.GetCollectionID())
log.Warn(msg) log.Warn(msg)
resp.Status = utils.WrapStatus(commonpb.ErrorCode_NoReplicaAvailable, msg) resp.Status = utils.WrapStatus(commonpb.ErrorCode_NoReplicaAvailable, msg)


@ -142,6 +142,7 @@ func (suite *ServiceSuite) SetupTest() {
suite.dist, suite.dist,
suite.broker, suite.broker,
) )
suite.targetObserver.Start(context.Background())
for _, node := range suite.nodes { for _, node := range suite.nodes {
suite.nodeMgr.Add(session.NewNodeInfo(node, "localhost")) suite.nodeMgr.Add(session.NewNodeInfo(node, "localhost"))
err := suite.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, node) err := suite.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, node)
@ -311,7 +312,6 @@ func (suite *ServiceSuite) TestLoadCollection() {
// Test load all collections // Test load all collections
for _, collection := range suite.collections { for _, collection := range suite.collections {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil)
suite.expectGetRecoverInfo(collection) suite.expectGetRecoverInfo(collection)
req := &querypb.LoadCollectionRequest{ req := &querypb.LoadCollectionRequest{
@ -776,6 +776,10 @@ func (suite *ServiceSuite) TestLoadPartition() {
// Test load all partitions // Test load all partitions
for _, collection := range suite.collections { for _, collection := range suite.collections {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).
Return(append(suite.partitions[collection], 999), nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collection, int64(999)).
Return(nil, nil, nil)
suite.expectGetRecoverInfo(collection) suite.expectGetRecoverInfo(collection)
req := &querypb.LoadPartitionsRequest{ req := &querypb.LoadPartitionsRequest{
@ -808,6 +812,36 @@ func (suite *ServiceSuite) TestLoadPartition() {
suite.NoError(err) suite.NoError(err)
suite.Equal(commonpb.ErrorCode_IllegalArgument, resp.ErrorCode) suite.Equal(commonpb.ErrorCode_IllegalArgument, resp.ErrorCode)
// Test load with collection loaded
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
continue
}
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
}
resp, err := server.LoadPartitions(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
}
// Test load with more partitions
suite.cluster.EXPECT().LoadPartitions(mock.Anything, mock.Anything, mock.Anything).
Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil)
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
continue
}
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: append(suite.partitions[collection], 999),
}
resp, err := server.LoadPartitions(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
}
// Test when server is not healthy // Test when server is not healthy
server.UpdateStateCode(commonpb.StateCode_Initializing) server.UpdateStateCode(commonpb.StateCode_Initializing)
req = &querypb.LoadPartitionsRequest{ req = &querypb.LoadPartitionsRequest{
@ -836,36 +870,6 @@ func (suite *ServiceSuite) TestLoadPartitionFailed() {
suite.Equal(commonpb.ErrorCode_IllegalArgument, resp.ErrorCode) suite.Equal(commonpb.ErrorCode_IllegalArgument, resp.ErrorCode)
suite.Contains(resp.Reason, job.ErrLoadParameterMismatched.Error()) suite.Contains(resp.Reason, job.ErrLoadParameterMismatched.Error())
} }
// Test load with collection loaded
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
continue
}
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
}
resp, err := server.LoadPartitions(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_IllegalArgument, resp.ErrorCode)
suite.Contains(resp.Reason, job.ErrLoadParameterMismatched.Error())
}
// Test load with more partitions
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
continue
}
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: append(suite.partitions[collection], 999),
}
resp, err := server.LoadPartitions(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_IllegalArgument, resp.ErrorCode)
suite.Contains(resp.Reason, job.ErrLoadParameterMismatched.Error())
}
} }
func (suite *ServiceSuite) TestReleaseCollection() { func (suite *ServiceSuite) TestReleaseCollection() {
@ -910,6 +914,8 @@ func (suite *ServiceSuite) TestReleasePartition() {
server := suite.server server := suite.server
// Test release all partitions // Test release all partitions
suite.cluster.EXPECT().ReleasePartitions(mock.Anything, mock.Anything, mock.Anything).
Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil)
for _, collection := range suite.collections { for _, collection := range suite.collections {
req := &querypb.ReleasePartitionsRequest{ req := &querypb.ReleasePartitionsRequest{
CollectionID: collection, CollectionID: collection,
@ -917,11 +923,7 @@ func (suite *ServiceSuite) TestReleasePartition() {
} }
resp, err := server.ReleasePartitions(ctx, req) resp, err := server.ReleasePartitions(ctx, req)
suite.NoError(err) suite.NoError(err)
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection {
suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.ErrorCode)
} else {
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode) suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
}
suite.assertPartitionLoaded(collection, suite.partitions[collection][1:]...) suite.assertPartitionLoaded(collection, suite.partitions[collection][1:]...)
} }
@ -933,11 +935,7 @@ func (suite *ServiceSuite) TestReleasePartition() {
} }
resp, err := server.ReleasePartitions(ctx, req) resp, err := server.ReleasePartitions(ctx, req)
suite.NoError(err) suite.NoError(err)
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection {
suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.ErrorCode)
} else {
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode) suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
}
suite.assertPartitionLoaded(collection, suite.partitions[collection][1:]...) suite.assertPartitionLoaded(collection, suite.partitions[collection][1:]...)
} }
@ -957,7 +955,6 @@ func (suite *ServiceSuite) TestRefreshCollection() {
defer cancel() defer cancel()
server := suite.server server := suite.server
suite.targetObserver.Start(context.Background())
suite.server.collectionObserver.Start(context.Background()) suite.server.collectionObserver.Start(context.Background())
// Test refresh all collections. // Test refresh all collections.
@ -970,7 +967,6 @@ func (suite *ServiceSuite) TestRefreshCollection() {
// Test load all collections // Test load all collections
for _, collection := range suite.collections { for _, collection := range suite.collections {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil)
suite.expectGetRecoverInfo(collection) suite.expectGetRecoverInfo(collection)
req := &querypb.LoadCollectionRequest{ req := &querypb.LoadCollectionRequest{
@ -1023,7 +1019,6 @@ func (suite *ServiceSuite) TestRefreshPartitions() {
defer cancel() defer cancel()
server := suite.server server := suite.server
suite.targetObserver.Start(context.Background())
suite.server.collectionObserver.Start(context.Background()) suite.server.collectionObserver.Start(context.Background())
// Test refresh all partitions. // Test refresh all partitions.
@ -1636,8 +1631,6 @@ func (suite *ServiceSuite) loadAll() {
for _, collection := range suite.collections { for _, collection := range suite.collections {
suite.expectGetRecoverInfo(collection) suite.expectGetRecoverInfo(collection)
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection { if suite.loadTypes[collection] == querypb.LoadType_LoadCollection {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil)
req := &querypb.LoadCollectionRequest{ req := &querypb.LoadCollectionRequest{
CollectionID: collection, CollectionID: collection,
ReplicaNumber: suite.replicaNumber[collection], ReplicaNumber: suite.replicaNumber[collection],
@ -1647,7 +1640,9 @@ func (suite *ServiceSuite) loadAll() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -1669,7 +1664,9 @@ func (suite *ServiceSuite) loadAll() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -1741,6 +1738,7 @@ func (suite *ServiceSuite) assertSegments(collection int64, segments []*querypb.
} }
func (suite *ServiceSuite) expectGetRecoverInfo(collection int64) { func (suite *ServiceSuite) expectGetRecoverInfo(collection int64) {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil)
vChannels := []*datapb.VchannelInfo{} vChannels := []*datapb.VchannelInfo{}
for _, channel := range suite.channels[collection] { for _, channel := range suite.channels[collection] {
vChannels = append(vChannels, &datapb.VchannelInfo{ vChannels = append(vChannels, &datapb.VchannelInfo{
@ -1848,7 +1846,7 @@ func (suite *ServiceSuite) updateCollectionStatus(collectionID int64, status que
} }
collection.CollectionLoadInfo.Status = status collection.CollectionLoadInfo.Status = status
suite.meta.UpdateCollection(collection) suite.meta.UpdateCollection(collection)
} else {
partitions := suite.meta.GetPartitionsByCollection(collectionID) partitions := suite.meta.GetPartitionsByCollection(collectionID)
for _, partition := range partitions { for _, partition := range partitions {
partition := partition.Clone() partition := partition.Clone()
@ -1869,6 +1867,10 @@ func (suite *ServiceSuite) fetchHeartbeats(time time.Time) {
} }
} }
func (suite *ServiceSuite) TearDownTest() {
suite.targetObserver.Stop()
}
func TestService(t *testing.T) { func TestService(t *testing.T) {
suite.Run(t, new(ServiceSuite)) suite.Run(t, new(ServiceSuite))
} }


@ -53,6 +53,8 @@ type Cluster interface {
UnsubDmChannel(ctx context.Context, nodeID int64, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) UnsubDmChannel(ctx context.Context, nodeID int64, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error)
LoadSegments(ctx context.Context, nodeID int64, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) LoadSegments(ctx context.Context, nodeID int64, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error)
ReleaseSegments(ctx context.Context, nodeID int64, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) ReleaseSegments(ctx context.Context, nodeID int64, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error)
LoadPartitions(ctx context.Context, nodeID int64, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error)
ReleasePartitions(ctx context.Context, nodeID int64, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error)
GetDataDistribution(ctx context.Context, nodeID int64, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) GetDataDistribution(ctx context.Context, nodeID int64, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error)
GetMetrics(ctx context.Context, nodeID int64, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) GetMetrics(ctx context.Context, nodeID int64, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error)
SyncDistribution(ctx context.Context, nodeID int64, req *querypb.SyncDistributionRequest) (*commonpb.Status, error) SyncDistribution(ctx context.Context, nodeID int64, req *querypb.SyncDistributionRequest) (*commonpb.Status, error)
@ -174,6 +176,34 @@ func (c *QueryCluster) ReleaseSegments(ctx context.Context, nodeID int64, req *q
return status, err return status, err
} }
func (c *QueryCluster) LoadPartitions(ctx context.Context, nodeID int64, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
var status *commonpb.Status
var err error
err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
req := proto.Clone(req).(*querypb.LoadPartitionsRequest)
req.Base.TargetID = nodeID
status, err = cli.LoadPartitions(ctx, req)
})
if err1 != nil {
return nil, err1
}
return status, err
}
func (c *QueryCluster) ReleasePartitions(ctx context.Context, nodeID int64, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
var status *commonpb.Status
var err error
err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
req := proto.Clone(req).(*querypb.ReleasePartitionsRequest)
req.Base.TargetID = nodeID
status, err = cli.ReleasePartitions(ctx, req)
})
if err1 != nil {
return nil, err1
}
return status, err
}
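Both wrappers follow the same pattern: clone the request, stamp Base.TargetID with the destination node, then forward the call, so a single request value can be fanned out to several QueryNodes without mutating a shared message. A hedged usage sketch follows (the loader interface and helper are illustrative and only match the signatures added above; imports are the same as in the earlier sketch).

	// partitionLoader matches the LoadPartitions method added to the Cluster interface above.
	type partitionLoader interface {
		LoadPartitions(ctx context.Context, nodeID int64, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error)
	}

	// loadPartitionsOnNodes pushes the same request to every node in the given list.
	func loadPartitionsOnNodes(ctx context.Context, cluster partitionLoader, nodeIDs []int64, req *querypb.LoadPartitionsRequest) error {
		for _, nodeID := range nodeIDs {
			status, err := cluster.LoadPartitions(ctx, nodeID, req)
			if err != nil {
				return err
			}
			if status.GetErrorCode() != commonpb.ErrorCode_Success {
				return fmt.Errorf("node %d failed to load partitions: %s", nodeID, status.GetReason())
			}
		}
		return nil
	}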
func (c *QueryCluster) GetDataDistribution(ctx context.Context, nodeID int64, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) { func (c *QueryCluster) GetDataDistribution(ctx context.Context, nodeID int64, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) {
var resp *querypb.GetDataDistributionResponse var resp *querypb.GetDataDistributionResponse
var err error var err error


@ -124,6 +124,14 @@ func (suite *ClusterTestSuite) createDefaultMockServer() querypb.QueryNodeServer
mock.Anything, mock.Anything,
mock.AnythingOfType("*querypb.ReleaseSegmentsRequest"), mock.AnythingOfType("*querypb.ReleaseSegmentsRequest"),
).Maybe().Return(succStatus, nil) ).Maybe().Return(succStatus, nil)
svr.EXPECT().LoadPartitions(
mock.Anything,
mock.AnythingOfType("*querypb.LoadPartitionsRequest"),
).Maybe().Return(succStatus, nil)
svr.EXPECT().ReleasePartitions(
mock.Anything,
mock.AnythingOfType("*querypb.ReleasePartitionsRequest"),
).Maybe().Return(succStatus, nil)
svr.EXPECT().GetDataDistribution( svr.EXPECT().GetDataDistribution(
mock.Anything, mock.Anything,
mock.AnythingOfType("*querypb.GetDataDistributionRequest"), mock.AnythingOfType("*querypb.GetDataDistributionRequest"),
@ -169,6 +177,14 @@ func (suite *ClusterTestSuite) createFailedMockServer() querypb.QueryNodeServer
mock.Anything, mock.Anything,
mock.AnythingOfType("*querypb.ReleaseSegmentsRequest"), mock.AnythingOfType("*querypb.ReleaseSegmentsRequest"),
).Maybe().Return(failStatus, nil) ).Maybe().Return(failStatus, nil)
svr.EXPECT().LoadPartitions(
mock.Anything,
mock.AnythingOfType("*querypb.LoadPartitionsRequest"),
).Maybe().Return(failStatus, nil)
svr.EXPECT().ReleasePartitions(
mock.Anything,
mock.AnythingOfType("*querypb.ReleasePartitionsRequest"),
).Maybe().Return(failStatus, nil)
svr.EXPECT().GetDataDistribution( svr.EXPECT().GetDataDistribution(
mock.Anything, mock.Anything,
mock.AnythingOfType("*querypb.GetDataDistributionRequest"), mock.AnythingOfType("*querypb.GetDataDistributionRequest"),
@ -284,6 +300,45 @@ func (suite *ClusterTestSuite) TestReleaseSegments() {
}, status) }, status)
} }
func (suite *ClusterTestSuite) TestLoadAndReleasePartitions() {
ctx := context.TODO()
status, err := suite.cluster.LoadPartitions(ctx, 0, &querypb.LoadPartitionsRequest{
Base: &commonpb.MsgBase{},
})
suite.NoError(err)
suite.Equal(&commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, status)
status, err = suite.cluster.LoadPartitions(ctx, 1, &querypb.LoadPartitionsRequest{
Base: &commonpb.MsgBase{},
})
suite.NoError(err)
suite.Equal(&commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unexpected error",
}, status)
status, err = suite.cluster.ReleasePartitions(ctx, 0, &querypb.ReleasePartitionsRequest{
Base: &commonpb.MsgBase{},
})
suite.NoError(err)
suite.Equal(&commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, status)
status, err = suite.cluster.ReleasePartitions(ctx, 1, &querypb.ReleasePartitionsRequest{
Base: &commonpb.MsgBase{},
})
suite.NoError(err)
suite.Equal(&commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unexpected error",
}, status)
}
func (suite *ClusterTestSuite) TestGetDataDistribution() { func (suite *ClusterTestSuite) TestGetDataDistribution() {
ctx := context.TODO() ctx := context.TODO()
resp, err := suite.cluster.GetDataDistribution(ctx, 0, &querypb.GetDataDistributionRequest{ resp, err := suite.cluster.GetDataDistribution(ctx, 0, &querypb.GetDataDistributionRequest{


@ -1,4 +1,4 @@
// Code generated by mockery v2.16.0. DO NOT EDIT. // Code generated by mockery v2.14.0. DO NOT EDIT.
package session package session
@ -170,6 +170,54 @@ func (_c *MockCluster_GetMetrics_Call) Return(_a0 *milvuspb.GetMetricsResponse,
return _c return _c
} }
// LoadPartitions provides a mock function with given fields: ctx, nodeID, req
func (_m *MockCluster) LoadPartitions(ctx context.Context, nodeID int64, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
ret := _m.Called(ctx, nodeID, req)
var r0 *commonpb.Status
if rf, ok := ret.Get(0).(func(context.Context, int64, *querypb.LoadPartitionsRequest) *commonpb.Status); ok {
r0 = rf(ctx, nodeID, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int64, *querypb.LoadPartitionsRequest) error); ok {
r1 = rf(ctx, nodeID, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCluster_LoadPartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadPartitions'
type MockCluster_LoadPartitions_Call struct {
*mock.Call
}
// LoadPartitions is a helper method to define mock.On call
// - ctx context.Context
// - nodeID int64
// - req *querypb.LoadPartitionsRequest
func (_e *MockCluster_Expecter) LoadPartitions(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_LoadPartitions_Call {
return &MockCluster_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", ctx, nodeID, req)}
}
func (_c *MockCluster_LoadPartitions_Call) Run(run func(ctx context.Context, nodeID int64, req *querypb.LoadPartitionsRequest)) *MockCluster_LoadPartitions_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64), args[2].(*querypb.LoadPartitionsRequest))
})
return _c
}
func (_c *MockCluster_LoadPartitions_Call) Return(_a0 *commonpb.Status, _a1 error) *MockCluster_LoadPartitions_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// LoadSegments provides a mock function with given fields: ctx, nodeID, req // LoadSegments provides a mock function with given fields: ctx, nodeID, req
func (_m *MockCluster) LoadSegments(ctx context.Context, nodeID int64, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) { func (_m *MockCluster) LoadSegments(ctx context.Context, nodeID int64, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
ret := _m.Called(ctx, nodeID, req) ret := _m.Called(ctx, nodeID, req)
@ -218,6 +266,54 @@ func (_c *MockCluster_LoadSegments_Call) Return(_a0 *commonpb.Status, _a1 error)
return _c return _c
} }
// ReleasePartitions provides a mock function with given fields: ctx, nodeID, req
func (_m *MockCluster) ReleasePartitions(ctx context.Context, nodeID int64, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
ret := _m.Called(ctx, nodeID, req)
var r0 *commonpb.Status
if rf, ok := ret.Get(0).(func(context.Context, int64, *querypb.ReleasePartitionsRequest) *commonpb.Status); ok {
r0 = rf(ctx, nodeID, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int64, *querypb.ReleasePartitionsRequest) error); ok {
r1 = rf(ctx, nodeID, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCluster_ReleasePartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleasePartitions'
type MockCluster_ReleasePartitions_Call struct {
*mock.Call
}
// ReleasePartitions is a helper method to define mock.On call
// - ctx context.Context
// - nodeID int64
// - req *querypb.ReleasePartitionsRequest
func (_e *MockCluster_Expecter) ReleasePartitions(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_ReleasePartitions_Call {
return &MockCluster_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", ctx, nodeID, req)}
}
func (_c *MockCluster_ReleasePartitions_Call) Run(run func(ctx context.Context, nodeID int64, req *querypb.ReleasePartitionsRequest)) *MockCluster_ReleasePartitions_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64), args[2].(*querypb.ReleasePartitionsRequest))
})
return _c
}
func (_c *MockCluster_ReleasePartitions_Call) Return(_a0 *commonpb.Status, _a1 error) *MockCluster_ReleasePartitions_Call {
_c.Call.Return(_a0, _a1)
return _c
}
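These generated expecter helpers are what the service tests above rely on. A typical stubbing in a suite looks like the following, mirroring the EXPECT calls used in TestLoadPartition and TestReleasePartition, with suite.cluster assumed to be this MockCluster:

	suite.cluster.EXPECT().
		LoadPartitions(mock.Anything, mock.Anything, mock.Anything).
		Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil)
	suite.cluster.EXPECT().
		ReleasePartitions(mock.Anything, mock.Anything, mock.Anything).
		Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil)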
// ReleaseSegments provides a mock function with given fields: ctx, nodeID, req // ReleaseSegments provides a mock function with given fields: ctx, nodeID, req
func (_m *MockCluster) ReleaseSegments(ctx context.Context, nodeID int64, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) { func (_m *MockCluster) ReleaseSegments(ctx context.Context, nodeID int64, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
ret := _m.Called(ctx, nodeID, req) ret := _m.Called(ctx, nodeID, req)


@ -247,7 +247,7 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error {
log.Warn("failed to get schema of collection", zap.Error(err)) log.Warn("failed to get schema of collection", zap.Error(err))
return err return err
} }
partitions, err := utils.GetPartitions(ex.meta.CollectionManager, ex.broker, task.CollectionID()) partitions, err := utils.GetPartitions(ex.meta.CollectionManager, task.CollectionID())
if err != nil { if err != nil {
log.Warn("failed to get partitions of collection", zap.Error(err)) log.Warn("failed to get partitions of collection", zap.Error(err))
return err return err
@ -388,7 +388,7 @@ func (ex *Executor) subDmChannel(task *ChannelTask, step int) error {
log.Warn("failed to get schema of collection") log.Warn("failed to get schema of collection")
return err return err
} }
partitions, err := utils.GetPartitions(ex.meta.CollectionManager, ex.broker, task.CollectionID()) partitions, err := utils.GetPartitions(ex.meta.CollectionManager, task.CollectionID())
if err != nil { if err != nil {
log.Warn("failed to get partitions of collection") log.Warn("failed to get partitions of collection")
return err return err


@ -185,8 +185,6 @@ func (suite *TaskSuite) TestSubscribeChannelTask() {
Return(&schemapb.CollectionSchema{ Return(&schemapb.CollectionSchema{
Name: "TestSubscribeChannelTask", Name: "TestSubscribeChannelTask",
}, nil) }, nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).
Return([]int64{100, 101}, nil)
channels := make([]*datapb.VchannelInfo, 0, len(suite.subChannels)) channels := make([]*datapb.VchannelInfo, 0, len(suite.subChannels))
for _, channel := range suite.subChannels { for _, channel := range suite.subChannels {
channels = append(channels, &datapb.VchannelInfo{ channels = append(channels, &datapb.VchannelInfo{
@ -234,6 +232,7 @@ func (suite *TaskSuite) TestSubscribeChannelTask() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(dmChannels, nil, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(dmChannels, nil, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
suite.AssertTaskNum(0, len(suite.subChannels), len(suite.subChannels), 0) suite.AssertTaskNum(0, len(suite.subChannels), len(suite.subChannels), 0)
@ -293,6 +292,7 @@ func (suite *TaskSuite) TestUnsubscribeChannelTask() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(dmChannels, nil, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(dmChannels, nil, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
@ -333,7 +333,6 @@ func (suite *TaskSuite) TestLoadSegmentTask() {
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{ suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{
Name: "TestLoadSegmentTask", Name: "TestLoadSegmentTask",
}, nil) }, nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{100, 101}, nil)
for _, segment := range suite.loadSegments { for _, segment := range suite.loadSegments {
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{ suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{
{ {
@ -374,6 +373,7 @@ func (suite *TaskSuite) TestLoadSegmentTask() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segments, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segments, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
segmentsNum := len(suite.loadSegments) segmentsNum := len(suite.loadSegments)
@ -417,7 +417,6 @@ func (suite *TaskSuite) TestLoadSegmentTaskFailed() {
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{ suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{
Name: "TestLoadSegmentTask", Name: "TestLoadSegmentTask",
}, nil) }, nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{100, 101}, nil)
for _, segment := range suite.loadSegments { for _, segment := range suite.loadSegments {
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{ suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{
{ {
@ -455,6 +454,7 @@ func (suite *TaskSuite) TestLoadSegmentTaskFailed() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segmentInfos, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segmentInfos, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
segmentsNum := len(suite.loadSegments) segmentsNum := len(suite.loadSegments)
@ -610,7 +610,6 @@ func (suite *TaskSuite) TestMoveSegmentTask() {
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{ suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{
Name: "TestMoveSegmentTask", Name: "TestMoveSegmentTask",
}, nil) }, nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{100, 101}, nil)
for _, segment := range suite.moveSegments { for _, segment := range suite.moveSegments {
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{ suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{
{ {
@ -665,6 +664,7 @@ func (suite *TaskSuite) TestMoveSegmentTask() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return([]*datapb.VchannelInfo{vchannel}, segmentInfos, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return([]*datapb.VchannelInfo{vchannel}, segmentInfos, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
suite.target.UpdateCollectionCurrentTarget(suite.collection, int64(1)) suite.target.UpdateCollectionCurrentTarget(suite.collection, int64(1))
@ -709,7 +709,6 @@ func (suite *TaskSuite) TestTaskCanceled() {
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{ suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{
Name: "TestSubscribeChannelTask", Name: "TestSubscribeChannelTask",
}, nil) }, nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{100, 101}, nil)
for _, segment := range suite.loadSegments { for _, segment := range suite.loadSegments {
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{ suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{
{ {
@ -752,6 +751,7 @@ func (suite *TaskSuite) TestTaskCanceled() {
} }
segmentsNum := len(suite.loadSegments) segmentsNum := len(suite.loadSegments)
suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum) suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum)
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{partition}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, partition).Return(nil, segmentInfos, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, partition).Return(nil, segmentInfos, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, partition) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, partition)
@ -787,7 +787,6 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{ suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{
Name: "TestSegmentTaskStale", Name: "TestSegmentTaskStale",
}, nil) }, nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{100, 101}, nil)
for _, segment := range suite.loadSegments { for _, segment := range suite.loadSegments {
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{ suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{
{ {
@ -829,6 +828,7 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segmentInfos, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segmentInfos, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
segmentsNum := len(suite.loadSegments) segmentsNum := len(suite.loadSegments)
@ -856,6 +856,10 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
InsertChannel: channel.GetChannelName(), InsertChannel: channel.GetChannelName(),
}) })
} }
bakExpectations := suite.broker.ExpectedCalls
suite.broker.AssertExpectations(suite.T())
suite.broker.ExpectedCalls = suite.broker.ExpectedCalls[:0]
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{2}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(2)).Return(nil, segmentInfos, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(2)).Return(nil, segmentInfos, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(2)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(2))
suite.dispatchAndWait(targetNode) suite.dispatchAndWait(targetNode)
@ -870,6 +874,7 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
suite.NoError(task.Err()) suite.NoError(task.Err())
} }
} }
suite.broker.ExpectedCalls = bakExpectations
} }
func (suite *TaskSuite) TestChannelTaskReplace() { func (suite *TaskSuite) TestChannelTaskReplace() {
@ -1060,6 +1065,7 @@ func (suite *TaskSuite) TestNoExecutor() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segments, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segments, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
segmentsNum := len(suite.loadSegments) segmentsNum := len(suite.loadSegments)


@ -17,7 +17,6 @@
package utils package utils
import ( import (
"context"
"fmt" "fmt"
"math/rand" "math/rand"
"sort" "sort"
@ -51,19 +50,16 @@ func GetReplicaNodesInfo(replicaMgr *meta.ReplicaManager, nodeMgr *session.NodeM
return nodes return nodes
} }
func GetPartitions(collectionMgr *meta.CollectionManager, broker meta.Broker, collectionID int64) ([]int64, error) { func GetPartitions(collectionMgr *meta.CollectionManager, collectionID int64) ([]int64, error) {
collection := collectionMgr.GetCollection(collectionID) collection := collectionMgr.GetCollection(collectionID)
if collection != nil { if collection != nil {
partitions, err := broker.GetPartitions(context.Background(), collectionID)
return partitions, err
}
partitions := collectionMgr.GetPartitionsByCollection(collectionID) partitions := collectionMgr.GetPartitionsByCollection(collectionID)
if partitions != nil { if partitions != nil {
return lo.Map(partitions, func(partition *meta.Partition, i int) int64 { return lo.Map(partitions, func(partition *meta.Partition, i int) int64 {
return partition.PartitionID return partition.PartitionID
}), nil }), nil
} }
}
// todo(yah01): replace this error with a defined error // todo(yah01): replace this error with a defined error
return nil, fmt.Errorf("collection/partition not loaded") return nil, fmt.Errorf("collection/partition not loaded")
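Read straight through, the refactored helper now answers purely from QueryCoord's own metadata instead of calling the broker. A condensed reading of the new body, assuming the imports already in this file:

	func GetPartitions(collectionMgr *meta.CollectionManager, collectionID int64) ([]int64, error) {
		if collection := collectionMgr.GetCollection(collectionID); collection != nil {
			if partitions := collectionMgr.GetPartitionsByCollection(collectionID); partitions != nil {
				return lo.Map(partitions, func(partition *meta.Partition, i int) int64 {
					return partition.PartitionID
				}), nil
			}
		}
		// todo(yah01): replace this error with a defined error
		return nil, fmt.Errorf("collection/partition not loaded")
	}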


@ -248,7 +248,7 @@ func TestFlowGraphDeleteNode_operate(t *testing.T) {
}, },
} }
msg := []flowgraph.Msg{&dMsg} msg := []flowgraph.Msg{&dMsg}
assert.Panics(t, func() { deleteNode.Operate(msg) }) deleteNode.Operate(msg)
}) })
t.Run("test partition not exist", func(t *testing.T) { t.Run("test partition not exist", func(t *testing.T) {


@ -139,12 +139,12 @@ func (fddNode *filterDeleteNode) filterInvalidDeleteMessage(msg *msgstream.Delet
return nil, nil return nil, nil
} }
if loadType == loadTypePartition { //if loadType == loadTypePartition {
if !fddNode.metaReplica.hasPartition(msg.PartitionID) { // if !fddNode.metaReplica.hasPartition(msg.PartitionID) {
// filter out msg which not belongs to the loaded partitions // // filter out msg which not belongs to the loaded partitions
return nil, nil // return nil, nil
} // }
} //}
return msg, nil return msg, nil
} }
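The partition check is commented out rather than removed, presumably so that a delete message is no longer dropped merely because its partition is not (yet) registered in the replica now that partitions can be loaded and released dynamically. A minimal sketch of the remaining idea, using a hypothetical helper name rather than code from this file (the msgstream types are the ones this file already imports):

	// keepDeleteMsg reports whether a delete message survives filtering after this change:
	// only the collection is compared; partition membership is intentionally not consulted.
	func keepDeleteMsg(msg *msgstream.DeleteMsg, servedCollectionID int64) bool {
		return msg.CollectionID == servedCollectionID
	}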


@ -89,7 +89,7 @@ func TestFlowGraphFilterDeleteNode_filterInvalidDeleteMessage(t *testing.T) {
res, err := fg.filterInvalidDeleteMessage(msg, loadTypePartition) res, err := fg.filterInvalidDeleteMessage(msg, loadTypePartition)
assert.NoError(t, err) assert.NoError(t, err)
assert.Nil(t, res) assert.NotNil(t, res)
}) })
} }


@ -162,12 +162,12 @@ func (fdmNode *filterDmNode) filterInvalidDeleteMessage(msg *msgstream.DeleteMsg
return nil, nil return nil, nil
} }
if loadType == loadTypePartition { //if loadType == loadTypePartition {
if !fdmNode.metaReplica.hasPartition(msg.PartitionID) { // if !fdmNode.metaReplica.hasPartition(msg.PartitionID) {
// filter out msg which not belongs to the loaded partitions // // filter out msg which not belongs to the loaded partitions
return nil, nil // return nil, nil
} // }
} //}
return msg, nil return msg, nil
} }
@ -198,12 +198,12 @@ func (fdmNode *filterDmNode) filterInvalidInsertMessage(msg *msgstream.InsertMsg
return nil, nil return nil, nil
} }
if loadType == loadTypePartition { //if loadType == loadTypePartition {
if !fdmNode.metaReplica.hasPartition(msg.PartitionID) { // if !fdmNode.metaReplica.hasPartition(msg.PartitionID) {
// filter out msg which not belongs to the loaded partitions // // filter out msg which not belongs to the loaded partitions
return nil, nil // return nil, nil
} // }
} //}
// Check if the segment is in excluded segments, // Check if the segment is in excluded segments,
// messages after seekPosition may contain the redundant data from flushed slice of segment, // messages after seekPosition may contain the redundant data from flushed slice of segment,


@ -71,18 +71,6 @@ func TestFlowGraphFilterDmNode_filterInvalidInsertMessage(t *testing.T) {
fg.collectionID = defaultCollectionID fg.collectionID = defaultCollectionID
}) })
t.Run("test no partition", func(t *testing.T) {
msg, err := genSimpleInsertMsg(schema, defaultMsgLength)
assert.NoError(t, err)
msg.PartitionID = UniqueID(1000)
fg, err := getFilterDMNode()
assert.NoError(t, err)
res, err := fg.filterInvalidInsertMessage(msg, loadTypePartition)
assert.NoError(t, err)
assert.Nil(t, res)
})
t.Run("test not target collection", func(t *testing.T) { t.Run("test not target collection", func(t *testing.T) {
msg, err := genSimpleInsertMsg(schema, defaultMsgLength) msg, err := genSimpleInsertMsg(schema, defaultMsgLength)
assert.NoError(t, err) assert.NoError(t, err)
@ -162,17 +150,6 @@ func TestFlowGraphFilterDmNode_filterInvalidDeleteMessage(t *testing.T) {
assert.NotNil(t, res) assert.NotNil(t, res)
}) })
t.Run("test delete no partition", func(t *testing.T) {
msg := genDeleteMsg(defaultCollectionID, schemapb.DataType_Int64, defaultDelLength)
msg.PartitionID = UniqueID(1000)
fg, err := getFilterDMNode()
assert.NoError(t, err)
res, err := fg.filterInvalidDeleteMessage(msg, loadTypePartition)
assert.NoError(t, err)
assert.Nil(t, res)
})
t.Run("test delete not target collection", func(t *testing.T) { t.Run("test delete not target collection", func(t *testing.T) {
msg := genDeleteMsg(defaultCollectionID, schemapb.DataType_Int64, defaultDelLength) msg := genDeleteMsg(defaultCollectionID, schemapb.DataType_Int64, defaultDelLength)
fg, err := getFilterDMNode() fg, err := getFilterDMNode()


@ -314,15 +314,6 @@ func processDeleteMessages(replica ReplicaInterface, segType segmentType, msg *m
var err error var err error
if msg.PartitionID != -1 { if msg.PartitionID != -1 {
partitionIDs = []UniqueID{msg.GetPartitionID()} partitionIDs = []UniqueID{msg.GetPartitionID()}
} else {
partitionIDs, err = replica.getPartitionIDs(msg.GetCollectionID())
if err != nil {
log.Warn("the collection has been released, ignore it",
zap.Int64("collectionID", msg.GetCollectionID()),
zap.Error(err),
)
return err
}
} }
var resultSegmentIDs []UniqueID var resultSegmentIDs []UniqueID
resultSegmentIDs, err = replica.getSegmentIDsByVChannel(partitionIDs, vchannelName, segType) resultSegmentIDs, err = replica.getSegmentIDsByVChannel(partitionIDs, vchannelName, segType)


@ -288,7 +288,7 @@ func TestFlowGraphInsertNode_operate(t *testing.T) {
}, },
} }
msg := []flowgraph.Msg{&iMsg} msg := []flowgraph.Msg{&iMsg}
assert.Panics(t, func() { insertNode.Operate(msg) }) insertNode.Operate(msg)
}) })
t.Run("test partition not exist", func(t *testing.T) { t.Run("test partition not exist", func(t *testing.T) {


@ -566,10 +566,10 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.Releas
return status, nil return status, nil
} }
// ReleasePartitions clears all data related to this partition on the querynode func (node *QueryNode) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { nodeID := node.session.ServerID
if !node.lifetime.Add(commonpbutil.IsHealthyOrStopping) { if !node.lifetime.Add(commonpbutil.IsHealthyOrStopping) {
err := fmt.Errorf("query node %d is not ready", node.GetSession().ServerID) err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
@ -578,35 +578,58 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.Releas
} }
defer node.lifetime.Done() defer node.lifetime.Done()
dct := &releasePartitionsTask{ // check target matches
baseTask: baseTask{ if req.GetBase().GetTargetID() != nodeID {
ctx: ctx, status := &commonpb.Status{
done: make(chan error), ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
}, Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), nodeID),
req: in, }
node: node, return status, nil
} }
err := node.scheduler.queue.Enqueue(dct) log.Ctx(ctx).With(zap.Int64("colID", req.GetCollectionID()), zap.Int64s("partIDs", req.GetPartitionIDs()))
log.Info("loading partitions")
for _, part := range req.GetPartitionIDs() {
err := node.metaReplica.addPartition(req.GetCollectionID(), part)
if err != nil { if err != nil {
log.Warn(err.Error())
}
}
log.Info("load partitions done")
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}
return status, nil
}
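Because the removed releasePartitionsTask body is interleaved with the new handler above, here is a condensed reading of the LoadPartitions handler QueryNode gains in this commit, with logging trimmed and names taken from the diff; the new ReleasePartitions below mirrors it, calling metaReplica.removePartition instead.

	func (node *QueryNode) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
		nodeID := node.session.ServerID
		if !node.lifetime.Add(commonpbutil.IsHealthyOrStopping) {
			return &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_UnexpectedError,
				Reason:    fmt.Sprintf("query node %d is not ready", nodeID),
			}, nil
		}
		defer node.lifetime.Done()

		// reject requests addressed to another node
		if req.GetBase().GetTargetID() != nodeID {
			return &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
				Reason:    common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), nodeID),
			}, nil
		}

		// register every requested partition in the local meta replica;
		// failures are logged and skipped rather than failing the whole request
		for _, part := range req.GetPartitionIDs() {
			if err := node.metaReplica.addPartition(req.GetCollectionID(), part); err != nil {
				log.Warn(err.Error())
			}
		}
		return &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, nil
	}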
// ReleasePartitions clears all data related to this partition on the querynode
func (node *QueryNode) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
nodeID := node.session.ServerID
if !node.lifetime.Add(commonpbutil.IsHealthyOrStopping) {
err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
} }
log.Warn(err.Error())
return status, nil return status, nil
} }
log.Info("releasePartitionsTask Enqueue done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("partitionIDs", in.PartitionIDs)) defer node.lifetime.Done()
func() { // check target matches
err = dct.WaitToFinish() if req.GetBase().GetTargetID() != nodeID {
if err != nil { status := &commonpb.Status{
log.Warn(err.Error()) ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
return Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), nodeID),
}
return status, nil
} }
log.Info("releasePartitionsTask WaitToFinish done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("partitionIDs", in.PartitionIDs))
}()
log.Ctx(ctx).With(zap.Int64("colID", req.GetCollectionID()), zap.Int64s("partIDs", req.GetPartitionIDs()))
log.Info("releasing partitions")
for _, part := range req.GetPartitionIDs() {
node.metaReplica.removePartition(part)
}
log.Info("release partitions done")
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success, ErrorCode: commonpb.ErrorCode_Success,
} }


@ -23,6 +23,10 @@ import (
"sync" "sync"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/common"
@ -33,10 +37,8 @@ import (
"github.com/milvus-io/milvus/internal/util/conc" "github.com/milvus-io/milvus/internal/util/conc"
"github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
) )
func TestImpl_GetComponentStates(t *testing.T) { func TestImpl_GetComponentStates(t *testing.T) {
@ -360,6 +362,36 @@ func TestImpl_ReleaseCollection(t *testing.T) {
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
} }
func TestImpl_LoadPartitions(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
req := &queryPb.LoadPartitionsRequest{
Base: &commonpb.MsgBase{
TargetID: paramtable.GetNodeID(),
},
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
}
status, err := node.LoadPartitions(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
node.UpdateStateCode(commonpb.StateCode_Abnormal)
status, err = node.LoadPartitions(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
node.UpdateStateCode(commonpb.StateCode_Healthy)
req.Base.TargetID = -1
status, err = node.LoadPartitions(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_NodeIDNotMatch, status.ErrorCode)
}
func TestImpl_ReleasePartitions(t *testing.T) { func TestImpl_ReleasePartitions(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -370,6 +402,7 @@ func TestImpl_ReleasePartitions(t *testing.T) {
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchQueryChannels, MsgType: commonpb.MsgType_WatchQueryChannels,
MsgID: rand.Int63(), MsgID: rand.Int63(),
TargetID: paramtable.GetNodeID(),
}, },
NodeID: 0, NodeID: 0,
CollectionID: defaultCollectionID, CollectionID: defaultCollectionID,
@ -384,6 +417,12 @@ func TestImpl_ReleasePartitions(t *testing.T) {
status, err = node.ReleasePartitions(ctx, req) status, err = node.ReleasePartitions(ctx, req)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
node.UpdateStateCode(commonpb.StateCode_Healthy)
req.Base.TargetID = -1
status, err = node.ReleasePartitions(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_NodeIDNotMatch, status.ErrorCode)
} }
func TestImpl_GetSegmentInfo(t *testing.T) { func TestImpl_GetSegmentInfo(t *testing.T) {


@ -458,14 +458,14 @@ func (replica *metaReplica) removePartitionPrivate(partitionID UniqueID) error {
} }
// delete segments // delete segments
ids, _ := partition.getSegmentIDs(segmentTypeGrowing) //ids, _ := partition.getSegmentIDs(segmentTypeGrowing)
for _, segmentID := range ids { //for _, segmentID := range ids {
replica.removeSegmentPrivate(segmentID, segmentTypeGrowing) // replica.removeSegmentPrivate(segmentID, segmentTypeGrowing)
} //}
ids, _ = partition.getSegmentIDs(segmentTypeSealed) //ids, _ = partition.getSegmentIDs(segmentTypeSealed)
for _, segmentID := range ids { //for _, segmentID := range ids {
replica.removeSegmentPrivate(segmentID, segmentTypeSealed) // replica.removeSegmentPrivate(segmentID, segmentTypeSealed)
} //}
collection.removePartitionID(partitionID) collection.removePartitionID(partitionID)
delete(replica.partitions, partitionID) delete(replica.partitions, partitionID)
@ -589,10 +589,6 @@ func (replica *metaReplica) addSegment(segmentID UniqueID, partitionID UniqueID,
// addSegmentPrivate is private function in collectionReplica, to add a new segment to collectionReplica // addSegmentPrivate is private function in collectionReplica, to add a new segment to collectionReplica
func (replica *metaReplica) addSegmentPrivate(segment *Segment) error { func (replica *metaReplica) addSegmentPrivate(segment *Segment) error {
segID := segment.segmentID segID := segment.segmentID
partition, err := replica.getPartitionByIDPrivate(segment.partitionID)
if err != nil {
return err
}
segType := segment.getType() segType := segment.getType()
ok, err := replica.hasSegmentPrivate(segID, segType) ok, err := replica.hasSegmentPrivate(segID, segType)
@ -603,12 +599,16 @@ func (replica *metaReplica) addSegmentPrivate(segment *Segment) error {
return fmt.Errorf("segment has been existed, "+
"segmentID = %d, collectionID = %d, segmentType = %s", segID, segment.collectionID, segType.String())
}
partition.addSegmentID(segID, segType)
switch segType {
case segmentTypeGrowing:
replica.growingSegments[segID] = segment
case segmentTypeSealed:
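// Only sealed segments are registered under their owning partition below;
// growing segments are tracked solely in the replica's growing-segment map.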
partition, err := replica.getPartitionByIDPrivate(segment.partitionID)
if err != nil {
return err
}
partition.addSegmentID(segID, segType)
replica.sealedSegments[segID] = segment
default:
return fmt.Errorf("unexpected segment type, segmentID = %d, segmentType = %s", segID, segType.String())
View File
@ -200,6 +200,7 @@ func TestStreaming_search(t *testing.T) {
collection, err := streaming.getCollectionByID(defaultCollectionID)
assert.NoError(t, err)
collection.setLoadType(loadTypeCollection)
searchReq, err := genSearchPlanAndRequests(collection, IndexFaissIDMap, defaultNQ)
assert.NoError(t, err)
View File
@ -76,6 +76,10 @@ func TestHistorical_statistic(t *testing.T) {
his, err := genSimpleReplicaWithSealSegment(ctx)
assert.NoError(t, err)
collection, err := his.getCollectionByID(defaultCollectionID)
assert.NoError(t, err)
collection.setLoadType(loadTypeCollection)
err = his.removePartition(defaultPartitionID)
assert.NoError(t, err)
@ -153,6 +157,10 @@ func TestStreaming_statistics(t *testing.T) {
streaming, err := genSimpleReplicaWithGrowingSegment()
assert.NoError(t, err)
collection, err := streaming.getCollectionByID(defaultCollectionID)
assert.NoError(t, err)
collection.setLoadType(loadTypeCollection)
err = streaming.removePartition(defaultPartitionID)
assert.NoError(t, err)
View File
@ -86,6 +86,9 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) {
t.Run("test validate after partition release", func(t *testing.T) {
his, err := genSimpleReplicaWithSealSegment(ctx)
assert.NoError(t, err)
collection, err := his.getCollectionByID(defaultCollectionID)
assert.NoError(t, err)
collection.setLoadType(loadTypeCollection)
err = his.removePartition(defaultPartitionID)
assert.NoError(t, err)
_, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID})
View File
@ -48,6 +48,8 @@ type watchInfo struct {
// Broker communicates with other components.
type Broker interface {
ReleaseCollection(ctx context.Context, collectionID UniqueID) error
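// ReleasePartitions notifies QueryCoord to release the given partitions of a collection.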
ReleasePartitions(ctx context.Context, collectionID UniqueID, partitionIDs ...UniqueID) error
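// SyncNewCreatedPartition notifies QueryCoord to load a partition that was created after its collection was loaded.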
SyncNewCreatedPartition(ctx context.Context, collectionID UniqueID, partitionID UniqueID) error
GetQuerySegmentInfo(ctx context.Context, collectionID int64, segIDs []int64) (retResp *querypb.GetSegmentInfoResponse, retErr error)
WatchChannels(ctx context.Context, info *watchInfo) error
@ -93,6 +95,49 @@ func (b *ServerBroker) ReleaseCollection(ctx context.Context, collectionID Uniqu
return nil
}
func (b *ServerBroker) ReleasePartitions(ctx context.Context, collectionID UniqueID, partitionIDs ...UniqueID) error {
if len(partitionIDs) == 0 {
return nil
}
log := log.Ctx(ctx).With(zap.Int64("collection", collectionID), zap.Int64s("partitionIDs", partitionIDs))
log.Info("releasing partitions")
resp, err := b.s.queryCoord.ReleasePartitions(ctx, &querypb.ReleasePartitionsRequest{
Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_ReleasePartitions)),
CollectionID: collectionID,
PartitionIDs: partitionIDs,
})
if err != nil {
return err
}
if resp.GetErrorCode() != commonpb.ErrorCode_Success {
return fmt.Errorf("release partition failed, reason: %s", resp.GetReason())
}
log.Info("release partitions done")
return nil
}
func (b *ServerBroker) SyncNewCreatedPartition(ctx context.Context, collectionID UniqueID, partitionID UniqueID) error {
log := log.Ctx(ctx).With(zap.Int64("collection", collectionID), zap.Int64("partitionID", partitionID))
log.Info("begin to sync new partition")
resp, err := b.s.queryCoord.SyncNewCreatedPartition(ctx, &querypb.SyncNewCreatedPartitionRequest{
Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_ReleasePartitions)),
CollectionID: collectionID,
PartitionID: partitionID,
})
if err != nil {
return err
}
if resp.GetErrorCode() != commonpb.ErrorCode_Success {
return fmt.Errorf("sync new partition failed, reason: %s", resp.GetReason())
}
log.Info("sync new partition done")
return nil
}
func (b *ServerBroker) GetQuerySegmentInfo(ctx context.Context, collectionID int64, segIDs []int64) (retResp *querypb.GetSegmentInfoResponse, retErr error) {
resp, err := b.s.queryCoord.GetSegmentInfo(ctx, &querypb.GetSegmentInfoRequest{
Base: commonpbutil.NewMsgBase(
View File
@ -73,7 +73,7 @@ func (t *createPartitionTask) Execute(ctx context.Context) error {
PartitionCreatedTimestamp: t.GetTs(),
Extra: nil,
CollectionID: t.collMeta.CollectionID,
State: pb.PartitionState_PartitionCreated,
State: pb.PartitionState_PartitionCreating,
}
undoTask := newBaseUndoTask(t.core.stepExecutor)
@ -88,5 +88,23 @@ func (t *createPartitionTask) Execute(ctx context.Context) error {
partition: partition,
}, &nullStep{}) // adding partition is atomic enough.
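// Tell QueryCoord about the new partition; if a later step fails, the paired undo step releases it again.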
undoTask.AddStep(&syncNewCreatedPartitionStep{
baseStep: baseStep{core: t.core},
collectionID: t.collMeta.CollectionID,
partitionID: partID,
}, &releasePartitionsStep{
baseStep: baseStep{core: t.core},
collectionID: t.collMeta.CollectionID,
partitionIDs: []int64{partID},
})
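// Flip the partition from PartitionCreating to PartitionCreated only after the sync step has succeeded.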
undoTask.AddStep(&changePartitionStateStep{
baseStep: baseStep{core: t.core},
collectionID: t.collMeta.CollectionID,
partitionID: partID,
state: pb.PartitionState_PartitionCreated,
ts: t.GetTs(),
}, &nullStep{})
return undoTask.Execute(ctx)
}
View File
@ -20,14 +20,13 @@ import (
"context"
"testing"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/internal/util/funcutil"
)
func Test_createPartitionTask_Prepare(t *testing.T) {
@ -147,7 +146,14 @@ func Test_createPartitionTask_Execute(t *testing.T) {
meta.AddPartitionFunc = func(ctx context.Context, partition *model.Partition) error {
return nil
}
core := newTestCore(withValidIDAllocator(), withValidProxyManager(), withMeta(meta))
meta.ChangePartitionStateFunc = func(ctx context.Context, collectionID UniqueID, partitionID UniqueID, state etcdpb.PartitionState, ts Timestamp) error {
return nil
}
b := newMockBroker()
b.SyncNewCreatedPartitionFunc = func(ctx context.Context, collectionID UniqueID, partitionID UniqueID) error {
return nil
}
core := newTestCore(withValidIDAllocator(), withValidProxyManager(), withMeta(meta), withBroker(b))
task := &createPartitionTask{
baseTask: baseTask{core: core},
collMeta: coll,
View File
@ -85,7 +85,12 @@ func (t *dropPartitionTask) Execute(ctx context.Context) error {
ts: t.GetTs(),
})
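// Release the dropped partition from QueryCoord as part of the asynchronous cleanup.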
// TODO: release partition when query coord is ready.
redoTask.AddAsyncStep(&releasePartitionsStep{
baseStep: baseStep{core: t.core},
collectionID: t.collMeta.CollectionID,
partitionIDs: []int64{partID},
})
redoTask.AddAsyncStep(&deletePartitionDataStep{
baseStep: baseStep{core: t.core},
pchans: t.collMeta.PhysicalChannelNames,
View File
@ -177,6 +177,9 @@ func Test_dropPartitionTask_Execute(t *testing.T) {
broker.DropCollectionIndexFunc = func(ctx context.Context, collID UniqueID, partIDs []UniqueID) error {
return nil
}
broker.ReleasePartitionsFunc = func(ctx context.Context, collectionID UniqueID, partitionIDs ...UniqueID) error {
return nil
}
core := newTestCore(
withValidProxyManager(),
View File
@ -512,7 +512,7 @@ func (mt *MetaTable) AddPartition(ctx context.Context, partition *model.Partitio
if !ok || !coll.Available() {
return fmt.Errorf("collection not exists: %d", partition.CollectionID)
}
if partition.State != pb.PartitionState_PartitionCreated {
if partition.State != pb.PartitionState_PartitionCreating {
return fmt.Errorf("partition state is not created, collection: %d, partition: %d, state: %s", partition.CollectionID, partition.PartitionID, partition.State)
}
if err := mt.catalog.CreatePartition(ctx, partition, partition.PartitionCreatedTimestamp); err != nil {
View File
@ -892,7 +892,7 @@ func TestMetaTable_AddPartition(t *testing.T) {
100: {Name: "test", CollectionID: 100},
},
}
err := meta.AddPartition(context.TODO(), &model.Partition{CollectionID: 100, State: pb.PartitionState_PartitionCreated})
err := meta.AddPartition(context.TODO(), &model.Partition{CollectionID: 100, State: pb.PartitionState_PartitionCreating})
assert.Error(t, err)
})
@ -909,7 +909,7 @@ func TestMetaTable_AddPartition(t *testing.T) {
100: {Name: "test", CollectionID: 100},
},
}
err := meta.AddPartition(context.TODO(), &model.Partition{CollectionID: 100, State: pb.PartitionState_PartitionCreated})
err := meta.AddPartition(context.TODO(), &model.Partition{CollectionID: 100, State: pb.PartitionState_PartitionCreating})
assert.NoError(t, err)
})
}
View File
@ -500,12 +500,20 @@ func withValidQueryCoord() Opt {
succStatus(), nil,
)
qc.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).Return(
succStatus(), nil,
)
qc.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).Return(
&querypb.GetSegmentInfoResponse{
Status: succStatus(),
}, nil,
)
qc.EXPECT().SyncNewCreatedPartition(mock.Anything, mock.Anything).Return(
succStatus(), nil,
)
return withQueryCoord(qc)
}
@ -780,6 +788,8 @@ type mockBroker struct {
Broker
ReleaseCollectionFunc func(ctx context.Context, collectionID UniqueID) error
ReleasePartitionsFunc func(ctx context.Context, collectionID UniqueID, partitionIDs ...UniqueID) error
SyncNewCreatedPartitionFunc func(ctx context.Context, collectionID UniqueID, partitionID UniqueID) error
GetQuerySegmentInfoFunc func(ctx context.Context, collectionID int64, segIDs []int64) (retResp *querypb.GetSegmentInfoResponse, retErr error)
WatchChannelsFunc func(ctx context.Context, info *watchInfo) error
@ -814,6 +824,14 @@ func (b mockBroker) ReleaseCollection(ctx context.Context, collectionID UniqueID
return b.ReleaseCollectionFunc(ctx, collectionID)
}
func (b mockBroker) ReleasePartitions(ctx context.Context, collectionID UniqueID, partitionIDs ...UniqueID) error {
return b.ReleasePartitionsFunc(ctx, collectionID, partitionIDs...)
}
func (b mockBroker) SyncNewCreatedPartition(ctx context.Context, collectionID UniqueID, partitionID UniqueID) error {
return b.SyncNewCreatedPartitionFunc(ctx, collectionID, partitionID)
}
func (b mockBroker) DropCollectionIndex(ctx context.Context, collID UniqueID, partIDs []UniqueID) error {
return b.DropCollectionIndexFunc(ctx, collID, partIDs)
}
View File
@ -273,6 +273,44 @@ func (s *releaseCollectionStep) Weight() stepPriority {
return stepPriorityUrgent
}
type releasePartitionsStep struct {
baseStep
collectionID UniqueID
partitionIDs []UniqueID
}
func (s *releasePartitionsStep) Execute(ctx context.Context) ([]nestedStep, error) {
err := s.core.broker.ReleasePartitions(ctx, s.collectionID, s.partitionIDs...)
return nil, err
}
func (s *releasePartitionsStep) Desc() string {
return fmt.Sprintf("release partitions, collectionID=%d, partitionIDs=%v", s.collectionID, s.partitionIDs)
}
func (s *releasePartitionsStep) Weight() stepPriority {
return stepPriorityUrgent
}
type syncNewCreatedPartitionStep struct {
baseStep
collectionID UniqueID
partitionID UniqueID
}
func (s *syncNewCreatedPartitionStep) Execute(ctx context.Context) ([]nestedStep, error) {
err := s.core.broker.SyncNewCreatedPartition(ctx, s.collectionID, s.partitionID)
return nil, err
}
func (s *syncNewCreatedPartitionStep) Desc() string {
return fmt.Sprintf("sync new partition, collectionID=%d, partitionID=%d", s.collectionID, s.partitionID)
}
func (s *syncNewCreatedPartitionStep) Weight() stepPriority {
return stepPriorityUrgent
}
type dropIndexStep struct {
baseStep
collID UniqueID
View File
@ -1,4 +1,4 @@
// Code generated by mockery v2.16.0. DO NOT EDIT.
// Code generated by mockery v2.14.0. DO NOT EDIT.
package types
@ -1316,6 +1316,53 @@ func (_c *MockQueryCoord_Stop_Call) Return(_a0 error) *MockQueryCoord_Stop_Call
return _c
}
// SyncNewCreatedPartition provides a mock function with given fields: ctx, req
func (_m *MockQueryCoord) SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error) {
ret := _m.Called(ctx, req)
var r0 *commonpb.Status
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SyncNewCreatedPartitionRequest) *commonpb.Status); ok {
r0 = rf(ctx, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SyncNewCreatedPartitionRequest) error); ok {
r1 = rf(ctx, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoord_SyncNewCreatedPartition_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SyncNewCreatedPartition'
type MockQueryCoord_SyncNewCreatedPartition_Call struct {
*mock.Call
}
// SyncNewCreatedPartition is a helper method to define mock.On call
// - ctx context.Context
// - req *querypb.SyncNewCreatedPartitionRequest
func (_e *MockQueryCoord_Expecter) SyncNewCreatedPartition(ctx interface{}, req interface{}) *MockQueryCoord_SyncNewCreatedPartition_Call {
return &MockQueryCoord_SyncNewCreatedPartition_Call{Call: _e.mock.On("SyncNewCreatedPartition", ctx, req)}
}
func (_c *MockQueryCoord_SyncNewCreatedPartition_Call) Run(run func(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest)) *MockQueryCoord_SyncNewCreatedPartition_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.SyncNewCreatedPartitionRequest))
})
return _c
}
func (_c *MockQueryCoord_SyncNewCreatedPartition_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_SyncNewCreatedPartition_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// TransferNode provides a mock function with given fields: ctx, req
func (_m *MockQueryCoord) TransferNode(ctx context.Context, req *milvuspb.TransferNodeRequest) (*commonpb.Status, error) {
ret := _m.Called(ctx, req)
View File
@ -1330,6 +1330,7 @@ type QueryNode interface {
// All the sealed segments are loaded.
LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error)
ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error)
LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error)
ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error)
ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error)
GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error)
@ -1374,6 +1375,7 @@ type QueryCoord interface {
ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error)
GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error)
GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error)
SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error)
LoadBalance(ctx context.Context, req *querypb.LoadBalanceRequest) (*commonpb.Status, error)
ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error)
View File
@ -82,6 +82,10 @@ func (m *GrpcQueryCoordClient) GetSegmentInfo(ctx context.Context, in *querypb.G
return &querypb.GetSegmentInfoResponse{}, m.Err
}
func (m *GrpcQueryCoordClient) SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err
}
func (m *GrpcQueryCoordClient) LoadBalance(ctx context.Context, in *querypb.LoadBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err
}
View File
@ -61,6 +61,10 @@ func (m *GrpcQueryNodeClient) ReleaseCollection(ctx context.Context, in *querypb
return &commonpb.Status{}, m.Err
}
func (m *GrpcQueryNodeClient) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err
}
func (m *GrpcQueryNodeClient) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err
}
View File
@ -77,6 +77,10 @@ func (q QueryNodeClient) ReleaseCollection(ctx context.Context, req *querypb.Rel
return q.grpcClient.ReleaseCollection(ctx, req)
}
func (q QueryNodeClient) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
return q.grpcClient.LoadPartitions(ctx, req)
}
func (q QueryNodeClient) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
return q.grpcClient.ReleasePartitions(ctx, req)
}
View File
@ -1116,9 +1116,7 @@ class TestCollectionOperation(TestcaseBase):
partition_w1.insert(cf.gen_default_list_data())
collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
collection_w.load()
error = {ct.err_code: 5, ct.err_msg: f'load the partition after load collection is not supported'}
partition_w1.load(check_task=CheckTasks.err_res,
check_items=error)
partition_w1.load()
@pytest.mark.tags(CaseLabel.L2)
def test_load_collection_release_partition(self):
@ -1133,9 +1131,7 @@ class TestCollectionOperation(TestcaseBase):
partition_w1.insert(cf.gen_default_list_data())
collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
collection_w.load()
error = {ct.err_code: 1, ct.err_msg: f'releasing the partition after load collection is not supported'}
partition_w1.release(check_task=CheckTasks.err_res,
check_items=error)
partition_w1.release()
@pytest.mark.tags(CaseLabel.L2)
def test_load_collection_after_release_collection(self): def test_load_collection_after_release_collection(self):