Dynamic load/release partitions (#22655)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
This commit is contained in:
yihao.dai 2023-03-20 14:55:57 +08:00 committed by GitHub
parent 7c633e9b9d
commit 1f718118e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
78 changed files with 3184 additions and 1574 deletions

View File

@ -272,6 +272,25 @@ func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart
return ret.(*commonpb.Status), err return ret.(*commonpb.Status), err
} }
// SyncNewCreatedPartition notifies QueryCoord to sync new created partition if collection is loaded.
func (c *Client) SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error) {
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID(), commonpbutil.WithTargetID(c.sess.ServerID)),
)
ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) {
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.SyncNewCreatedPartition(ctx, req)
})
if err != nil || ret == nil {
return nil, err
}
return ret.(*commonpb.Status), err
}
// GetPartitionStates gets the states of the specified partition. // GetPartitionStates gets the states of the specified partition.
func (c *Client) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) { func (c *Client) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) {
req = typeutil.Clone(req) req = typeutil.Clone(req)

View File

@ -114,6 +114,9 @@ func Test_NewClient(t *testing.T) {
r7, err := client.ReleasePartitions(ctx, nil) r7, err := client.ReleasePartitions(ctx, nil)
retCheck(retNotNil, r7, err) retCheck(retNotNil, r7, err)
r7, err = client.SyncNewCreatedPartition(ctx, nil)
retCheck(retNotNil, r7, err)
r8, err := client.ShowCollections(ctx, nil) r8, err := client.ShowCollections(ctx, nil)
retCheck(retNotNil, r8, err) retCheck(retNotNil, r8, err)

View File

@ -329,6 +329,11 @@ func (s *Server) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart
return s.queryCoord.ReleasePartitions(ctx, req) return s.queryCoord.ReleasePartitions(ctx, req)
} }
// SyncNewCreatedPartition notifies QueryCoord to sync new created partition if collection is loaded.
func (s *Server) SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error) {
return s.queryCoord.SyncNewCreatedPartition(ctx, req)
}
// GetSegmentInfo gets the information of the specified segment from QueryCoord. // GetSegmentInfo gets the information of the specified segment from QueryCoord.
func (s *Server) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { func (s *Server) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
return s.queryCoord.GetSegmentInfo(ctx, req) return s.queryCoord.GetSegmentInfo(ctx, req)

View File

@ -213,6 +213,24 @@ func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl
return ret.(*commonpb.Status), err return ret.(*commonpb.Status), err
} }
// LoadPartitions updates partitions meta info in QueryNode.
func (c *Client) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
req = typeutil.Clone(req)
commonpbutil.UpdateMsgBase(
req.GetBase(),
commonpbutil.FillMsgBaseFromClient(paramtable.GetNodeID()))
ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryNodeClient) (any, error) {
if !funcutil.CheckCtxValid(ctx) {
return nil, ctx.Err()
}
return client.LoadPartitions(ctx, req)
})
if err != nil || ret == nil {
return nil, err
}
return ret.(*commonpb.Status), err
}
// ReleasePartitions releases the data of the specified partitions in QueryNode. // ReleasePartitions releases the data of the specified partitions in QueryNode.
func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
req = typeutil.Clone(req) req = typeutil.Clone(req)

View File

@ -78,6 +78,9 @@ func Test_NewClient(t *testing.T) {
r8, err := client.ReleaseCollection(ctx, nil) r8, err := client.ReleaseCollection(ctx, nil)
retCheck(retNotNil, r8, err) retCheck(retNotNil, r8, err)
r8, err = client.LoadPartitions(ctx, nil)
retCheck(retNotNil, r8, err)
r9, err := client.ReleasePartitions(ctx, nil) r9, err := client.ReleasePartitions(ctx, nil)
retCheck(retNotNil, r9, err) retCheck(retNotNil, r9, err)

View File

@ -276,6 +276,11 @@ func (s *Server) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl
return s.querynode.ReleaseCollection(ctx, req) return s.querynode.ReleaseCollection(ctx, req)
} }
// LoadPartitions updates partitions meta info in QueryNode.
func (s *Server) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
return s.querynode.LoadPartitions(ctx, req)
}
// ReleasePartitions releases the data of the specified partitions in QueryNode. // ReleasePartitions releases the data of the specified partitions in QueryNode.
func (s *Server) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { func (s *Server) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
// ignore ctx // ignore ctx

View File

@ -94,6 +94,10 @@ func (m *MockQueryNode) ReleaseCollection(ctx context.Context, req *querypb.Rele
return m.status, m.err return m.status, m.err
} }
func (m *MockQueryNode) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
return m.status, m.err
}
func (m *MockQueryNode) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { func (m *MockQueryNode) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
return m.status, m.err return m.status, m.err
} }
@ -263,6 +267,13 @@ func Test_NewServer(t *testing.T) {
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode) assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
}) })
t.Run("LoadPartitions", func(t *testing.T) {
req := &querypb.LoadPartitionsRequest{}
resp, err := server.LoadPartitions(ctx, req)
assert.Nil(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
})
t.Run("ReleasePartitions", func(t *testing.T) { t.Run("ReleasePartitions", func(t *testing.T) {
req := &querypb.ReleasePartitionsRequest{} req := &querypb.ReleasePartitionsRequest{}
resp, err := server.ReleasePartitions(ctx, req) resp, err := server.ReleasePartitions(ctx, req)

View File

@ -150,13 +150,13 @@ type IndexCoordCatalog interface {
} }
type QueryCoordCatalog interface { type QueryCoordCatalog interface {
SaveCollection(info *querypb.CollectionLoadInfo) error SaveCollection(collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo) error
SavePartition(info ...*querypb.PartitionLoadInfo) error SavePartition(info ...*querypb.PartitionLoadInfo) error
SaveReplica(replica *querypb.Replica) error SaveReplica(replica *querypb.Replica) error
GetCollections() ([]*querypb.CollectionLoadInfo, error) GetCollections() ([]*querypb.CollectionLoadInfo, error)
GetPartitions() (map[int64][]*querypb.PartitionLoadInfo, error) GetPartitions() (map[int64][]*querypb.PartitionLoadInfo, error)
GetReplicas() ([]*querypb.Replica, error) GetReplicas() ([]*querypb.Replica, error)
ReleaseCollection(id int64) error ReleaseCollection(collection int64) error
ReleasePartition(collection int64, partitions ...int64) error ReleasePartition(collection int64, partitions ...int64) error
ReleaseReplicas(collectionID int64) error ReleaseReplicas(collectionID int64) error
ReleaseReplica(collection, replica int64) error ReleaseReplica(collection, replica int64) error

View File

@ -31,12 +31,12 @@ var (
Help: "number of collections", Help: "number of collections",
}, []string{}) }, []string{})
QueryCoordNumEntities = prometheus.NewGaugeVec( QueryCoordNumPartitions = prometheus.NewGaugeVec(
prometheus.GaugeOpts{ prometheus.GaugeOpts{
Namespace: milvusNamespace, Namespace: milvusNamespace,
Subsystem: typeutil.QueryCoordRole, Subsystem: typeutil.QueryCoordRole,
Name: "entity_num", Name: "partition_num",
Help: "number of entities", Help: "number of partitions",
}, []string{}) }, []string{})
QueryCoordLoadCount = prometheus.NewCounterVec( QueryCoordLoadCount = prometheus.NewCounterVec(
@ -97,7 +97,7 @@ var (
// RegisterQueryCoord registers QueryCoord metrics // RegisterQueryCoord registers QueryCoord metrics
func RegisterQueryCoord(registry *prometheus.Registry) { func RegisterQueryCoord(registry *prometheus.Registry) {
registry.MustRegister(QueryCoordNumCollections) registry.MustRegister(QueryCoordNumCollections)
registry.MustRegister(QueryCoordNumEntities) registry.MustRegister(QueryCoordNumPartitions)
registry.MustRegister(QueryCoordLoadCount) registry.MustRegister(QueryCoordLoadCount)
registry.MustRegister(QueryCoordReleaseCount) registry.MustRegister(QueryCoordReleaseCount)
registry.MustRegister(QueryCoordLoadLatency) registry.MustRegister(QueryCoordLoadLatency)

View File

@ -23,6 +23,7 @@ service QueryCoord {
rpc ReleasePartitions(ReleasePartitionsRequest) returns (common.Status) {} rpc ReleasePartitions(ReleasePartitionsRequest) returns (common.Status) {}
rpc LoadCollection(LoadCollectionRequest) returns (common.Status) {} rpc LoadCollection(LoadCollectionRequest) returns (common.Status) {}
rpc ReleaseCollection(ReleaseCollectionRequest) returns (common.Status) {} rpc ReleaseCollection(ReleaseCollectionRequest) returns (common.Status) {}
rpc SyncNewCreatedPartition(SyncNewCreatedPartitionRequest) returns (common.Status) {}
rpc GetPartitionStates(GetPartitionStatesRequest) returns (GetPartitionStatesResponse) {} rpc GetPartitionStates(GetPartitionStatesRequest) returns (GetPartitionStatesResponse) {}
rpc GetSegmentInfo(GetSegmentInfoRequest) returns (GetSegmentInfoResponse) {} rpc GetSegmentInfo(GetSegmentInfoRequest) returns (GetSegmentInfoResponse) {}
@ -55,6 +56,7 @@ service QueryNode {
rpc UnsubDmChannel(UnsubDmChannelRequest) returns (common.Status) {} rpc UnsubDmChannel(UnsubDmChannelRequest) returns (common.Status) {}
rpc LoadSegments(LoadSegmentsRequest) returns (common.Status) {} rpc LoadSegments(LoadSegmentsRequest) returns (common.Status) {}
rpc ReleaseCollection(ReleaseCollectionRequest) returns (common.Status) {} rpc ReleaseCollection(ReleaseCollectionRequest) returns (common.Status) {}
rpc LoadPartitions(LoadPartitionsRequest) returns (common.Status) {}
rpc ReleasePartitions(ReleasePartitionsRequest) returns (common.Status) {} rpc ReleasePartitions(ReleasePartitionsRequest) returns (common.Status) {}
rpc ReleaseSegments(ReleaseSegmentsRequest) returns (common.Status) {} rpc ReleaseSegments(ReleaseSegmentsRequest) returns (common.Status) {}
rpc GetSegmentInfo(GetSegmentInfoRequest) returns (GetSegmentInfoResponse) {} rpc GetSegmentInfo(GetSegmentInfoRequest) returns (GetSegmentInfoResponse) {}
@ -189,6 +191,12 @@ message ShardLeadersList { // All leaders of all replicas of one shard
repeated string node_addrs = 3; repeated string node_addrs = 3;
} }
message SyncNewCreatedPartitionRequest {
common.MsgBase base = 1;
int64 collectionID = 2;
int64 partitionID = 3;
}
//-----------------query node grpc request and response proto---------------- //-----------------query node grpc request and response proto----------------
message LoadMetaInfo { message LoadMetaInfo {
LoadType load_type = 1; LoadType load_type = 1;
@ -482,18 +490,19 @@ enum LoadStatus {
message CollectionLoadInfo { message CollectionLoadInfo {
int64 collectionID = 1; int64 collectionID = 1;
repeated int64 released_partitions = 2; repeated int64 released_partitions = 2; // Deprecated: No longer used; kept for compatibility.
int32 replica_number = 3; int32 replica_number = 3;
LoadStatus status = 4; LoadStatus status = 4;
map<int64, int64> field_indexID = 5; map<int64, int64> field_indexID = 5;
LoadType load_type = 6;
} }
message PartitionLoadInfo { message PartitionLoadInfo {
int64 collectionID = 1; int64 collectionID = 1;
int64 partitionID = 2; int64 partitionID = 2;
int32 replica_number = 3; int32 replica_number = 3; // Deprecated: No longer used; kept for compatibility.
LoadStatus status = 4; LoadStatus status = 4;
map<int64, int64> field_indexID = 5; map<int64, int64> field_indexID = 5; // Deprecated: No longer used; kept for compatibility.
} }
message Replica { message Replica {

File diff suppressed because it is too large Load Diff

View File

@ -1980,7 +1980,7 @@ func TestProxy(t *testing.T) {
Type: milvuspb.ShowType_InMemory, Type: milvuspb.ShowType_InMemory,
}) })
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, resp.Status.ErrorCode) assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
// default partition // default partition
assert.Equal(t, 0, len(resp.PartitionNames)) assert.Equal(t, 0, len(resp.PartitionNames))

View File

@ -85,6 +85,10 @@ func (m *QueryNodeMock) ReleaseCollection(ctx context.Context, req *querypb.Rele
return nil, nil return nil, nil
} }
func (m *QueryNodeMock) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
return nil, nil
}
// TODO // TODO
func (m *QueryNodeMock) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { func (m *QueryNodeMock) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
return nil, nil return nil, nil

View File

@ -868,32 +868,6 @@ func (dpt *dropPartitionTask) PreExecute(ctx context.Context) error {
return err return err
} }
collID, err := globalMetaCache.GetCollectionID(ctx, dpt.GetCollectionName())
if err != nil {
return err
}
partID, err := globalMetaCache.GetPartitionID(ctx, dpt.GetCollectionName(), dpt.GetPartitionName())
if err != nil {
if err.Error() == ErrPartitionNotExist(dpt.GetPartitionName()).Error() {
return nil
}
return err
}
collLoaded, err := isCollectionLoaded(ctx, dpt.queryCoord, collID)
if err != nil {
return err
}
if collLoaded {
loaded, err := isPartitionLoaded(ctx, dpt.queryCoord, collID, []int64{partID})
if err != nil {
return err
}
if loaded {
return errors.New("partition cannot be dropped, partition is loaded, please release it first")
}
}
return nil return nil
} }
@ -1587,6 +1561,9 @@ func (lpt *loadPartitionsTask) Execute(ctx context.Context) error {
} }
partitionIDs = append(partitionIDs, partitionID) partitionIDs = append(partitionIDs, partitionID)
} }
if len(partitionIDs) == 0 {
return errors.New("failed to load partition, due to no partition specified")
}
request := &querypb.LoadPartitionsRequest{ request := &querypb.LoadPartitionsRequest{
Base: commonpbutil.UpdateMsgBase( Base: commonpbutil.UpdateMsgBase(
lpt.Base, lpt.Base,

View File

@ -1134,20 +1134,6 @@ func TestDropPartitionTask(t *testing.T) {
err = task.PreExecute(ctx) err = task.PreExecute(ctx)
assert.NotNil(t, err) assert.NotNil(t, err)
t.Run("get collectionID error", func(t *testing.T) {
mockCache := newMockCache()
mockCache.setGetPartitionIDFunc(func(ctx context.Context, collectionName string, partitionName string) (typeutil.UniqueID, error) {
return 1, nil
})
mockCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
return 0, errors.New("error")
})
globalMetaCache = mockCache
task.PartitionName = "partition1"
err = task.PreExecute(ctx)
assert.Error(t, err)
})
t.Run("partition not exist", func(t *testing.T) { t.Run("partition not exist", func(t *testing.T) {
task.PartitionName = "partition2" task.PartitionName = "partition2"
@ -1162,21 +1148,6 @@ func TestDropPartitionTask(t *testing.T) {
err = task.PreExecute(ctx) err = task.PreExecute(ctx)
assert.NoError(t, err) assert.NoError(t, err)
}) })
t.Run("get partition error", func(t *testing.T) {
task.PartitionName = "partition3"
mockCache := newMockCache()
mockCache.setGetPartitionIDFunc(func(ctx context.Context, collectionName string, partitionName string) (typeutil.UniqueID, error) {
return 0, errors.New("error")
})
mockCache.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
return 1, nil
})
globalMetaCache = mockCache
err = task.PreExecute(ctx)
assert.Error(t, err)
})
} }
func TestHasPartitionTask(t *testing.T) { func TestHasPartitionTask(t *testing.T) {

View File

@ -19,13 +19,14 @@ package balance
import ( import (
"sort" "sort"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta" "github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/internal/querycoordv2/task"
"github.com/samber/lo"
"go.uber.org/zap"
) )
type RowCountBasedBalancer struct { type RowCountBasedBalancer struct {

View File

@ -69,6 +69,8 @@ func (suite *RowCountBasedBalancerTestSuite) SetupTest() {
distManager := meta.NewDistributionManager() distManager := meta.NewDistributionManager()
suite.mockScheduler = task.NewMockScheduler(suite.T()) suite.mockScheduler = task.NewMockScheduler(suite.T())
suite.balancer = NewRowCountBasedBalancer(suite.mockScheduler, nodeManager, distManager, testMeta, testTarget) suite.balancer = NewRowCountBasedBalancer(suite.mockScheduler, nodeManager, distManager, testMeta, testTarget)
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil).Maybe()
} }
func (suite *RowCountBasedBalancerTestSuite) TearDownTest() { func (suite *RowCountBasedBalancerTestSuite) TearDownTest() {
@ -257,6 +259,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalance() {
balancer.targetMgr.UpdateCollectionCurrentTarget(1, 1) balancer.targetMgr.UpdateCollectionCurrentTarget(1, 1)
collection.LoadPercentage = 100 collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded collection.Status = querypb.LoadStatus_Loaded
collection.LoadType = querypb.LoadType_LoadCollection
balancer.meta.CollectionManager.PutCollection(collection) balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...))) balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...)))
for node, s := range c.distributions { for node, s := range c.distributions {
@ -359,6 +362,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOutboundNodes() {
balancer.targetMgr.UpdateCollectionCurrentTarget(1, 1) balancer.targetMgr.UpdateCollectionCurrentTarget(1, 1)
collection.LoadPercentage = 100 collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loaded collection.Status = querypb.LoadStatus_Loaded
collection.LoadType = querypb.LoadType_LoadCollection
balancer.meta.CollectionManager.PutCollection(collection) balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...))) balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, append(c.nodes, c.notExistedNodes...)))
for node, s := range c.distributions { for node, s := range c.distributions {
@ -415,6 +419,7 @@ func (suite *RowCountBasedBalancerTestSuite) TestBalanceOnLoadingCollection() {
collection := utils.CreateTestCollection(1, 1) collection := utils.CreateTestCollection(1, 1)
collection.LoadPercentage = 100 collection.LoadPercentage = 100
collection.Status = querypb.LoadStatus_Loading collection.Status = querypb.LoadStatus_Loading
collection.LoadType = querypb.LoadType_LoadCollection
balancer.meta.CollectionManager.PutCollection(collection) balancer.meta.CollectionManager.PutCollection(collection)
balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, c.nodes)) balancer.meta.ReplicaManager.Put(utils.CreateTestReplica(1, 1, c.nodes))
for node, s := range c.distributions { for node, s := range c.distributions {

View File

@ -74,6 +74,8 @@ func (suite *ChannelCheckerTestSuite) SetupTest() {
balancer := suite.createMockBalancer() balancer := suite.createMockBalancer()
suite.checker = NewChannelChecker(suite.meta, distManager, targetManager, balancer) suite.checker = NewChannelChecker(suite.meta, distManager, targetManager, balancer)
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil).Maybe()
} }
func (suite *ChannelCheckerTestSuite) TearDownTest() { func (suite *ChannelCheckerTestSuite) TearDownTest() {

View File

@ -74,6 +74,8 @@ func (suite *SegmentCheckerTestSuite) SetupTest() {
balancer := suite.createMockBalancer() balancer := suite.createMockBalancer()
suite.checker = NewSegmentChecker(suite.meta, distManager, targetManager, balancer) suite.checker = NewSegmentChecker(suite.meta, distManager, targetManager, balancer)
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil).Maybe()
} }
func (suite *SegmentCheckerTestSuite) TearDownTest() { func (suite *SegmentCheckerTestSuite) TearDownTest() {

View File

@ -26,4 +26,5 @@ var (
ErrCollectionLoaded = errors.New("CollectionLoaded") ErrCollectionLoaded = errors.New("CollectionLoaded")
ErrLoadParameterMismatched = errors.New("LoadParameterMismatched") ErrLoadParameterMismatched = errors.New("LoadParameterMismatched")
ErrNoEnoughNode = errors.New("NoEnoughNode") ErrNoEnoughNode = errors.New("NoEnoughNode")
ErrPartitionNotInTarget = errors.New("PartitionNotInLoadingTarget")
) )

View File

@ -18,20 +18,6 @@ package job
import ( import (
"context" "context"
"fmt"
"time"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/observers"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/internal/util/typeutil"
) )
// Job is request of loading/releasing collection/partitions, // Job is request of loading/releasing collection/partitions,
@ -106,439 +92,3 @@ func (job *BaseJob) PreExecute() error {
} }
func (job *BaseJob) PostExecute() {} func (job *BaseJob) PostExecute() {}
type LoadCollectionJob struct {
*BaseJob
req *querypb.LoadCollectionRequest
dist *meta.DistributionManager
meta *meta.Meta
targetMgr *meta.TargetManager
broker meta.Broker
nodeMgr *session.NodeManager
}
func NewLoadCollectionJob(
ctx context.Context,
req *querypb.LoadCollectionRequest,
dist *meta.DistributionManager,
meta *meta.Meta,
targetMgr *meta.TargetManager,
broker meta.Broker,
nodeMgr *session.NodeManager,
) *LoadCollectionJob {
return &LoadCollectionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
dist: dist,
meta: meta,
targetMgr: targetMgr,
broker: broker,
nodeMgr: nodeMgr,
}
}
func (job *LoadCollectionJob) PreExecute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
)
if req.GetReplicaNumber() <= 0 {
log.Info("request doesn't indicate the number of replicas, set it to 1",
zap.Int32("replicaNumber", req.GetReplicaNumber()))
req.ReplicaNumber = 1
}
if job.meta.Exist(req.GetCollectionID()) {
old := job.meta.GetCollection(req.GetCollectionID())
if old == nil {
msg := "load the partition after load collection is not supported"
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
} else if old.GetReplicaNumber() != req.GetReplicaNumber() {
msg := fmt.Sprintf("collection with different replica number %d existed, release this collection first before changing its replica number",
job.meta.GetReplicaNumber(req.GetCollectionID()),
)
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
} else if !typeutil.MapEqual(old.GetFieldIndexID(), req.GetFieldIndexID()) {
msg := fmt.Sprintf("collection with different index %v existed, release this collection first before changing its index",
old.GetFieldIndexID())
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
}
return ErrCollectionLoaded
}
return nil
}
func (job *LoadCollectionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
)
meta.GlobalFailedLoadCache.Remove(req.GetCollectionID())
// Clear stale replicas
err := job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
if err != nil {
log.Warn("failed to clear stale replicas", zap.Error(err))
return err
}
// Create replicas
replicas, err := utils.SpawnReplicasWithRG(job.meta,
req.GetCollectionID(),
req.GetResourceGroups(),
req.GetReplicaNumber(),
)
if err != nil {
msg := "failed to spawn replica for collection"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
for _, replica := range replicas {
log.Info("replica created",
zap.Int64("replicaID", replica.GetID()),
zap.Int64s("nodes", replica.GetNodes()),
zap.String("resourceGroup", replica.GetResourceGroup()))
}
// Fetch channels and segments from DataCoord
partitionIDs, err := job.broker.GetPartitions(job.ctx, req.GetCollectionID())
if err != nil {
msg := "failed to get partitions from RootCoord"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
// It's safe here to call UpdateCollectionNextTargetWithPartitions, as the collection not existing
err = job.targetMgr.UpdateCollectionNextTargetWithPartitions(req.GetCollectionID(), partitionIDs...)
if err != nil {
msg := "failed to update next targets for collection"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
err = job.meta.CollectionManager.PutCollection(&meta.Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: req.GetCollectionID(),
ReplicaNumber: req.GetReplicaNumber(),
Status: querypb.LoadStatus_Loading,
FieldIndexID: req.GetFieldIndexID(),
},
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
})
if err != nil {
msg := "failed to store collection"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
metrics.QueryCoordNumCollections.WithLabelValues().Inc()
return nil
}
func (job *LoadCollectionJob) PostExecute() {
if job.Error() != nil && !job.meta.Exist(job.CollectionID()) {
job.meta.ReplicaManager.RemoveCollection(job.CollectionID())
job.targetMgr.RemoveCollection(job.req.GetCollectionID())
}
}
type ReleaseCollectionJob struct {
*BaseJob
req *querypb.ReleaseCollectionRequest
dist *meta.DistributionManager
meta *meta.Meta
targetMgr *meta.TargetManager
targetObserver *observers.TargetObserver
}
func NewReleaseCollectionJob(ctx context.Context,
req *querypb.ReleaseCollectionRequest,
dist *meta.DistributionManager,
meta *meta.Meta,
targetMgr *meta.TargetManager,
targetObserver *observers.TargetObserver,
) *ReleaseCollectionJob {
return &ReleaseCollectionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
dist: dist,
meta: meta,
targetMgr: targetMgr,
targetObserver: targetObserver,
}
}
func (job *ReleaseCollectionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
)
if !job.meta.CollectionManager.Exist(req.GetCollectionID()) {
log.Info("release collection end, the collection has not been loaded into QueryNode")
return nil
}
err := job.meta.CollectionManager.RemoveCollection(req.GetCollectionID())
if err != nil {
msg := "failed to remove collection"
log.Warn(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
if err != nil {
msg := "failed to remove replicas"
log.Warn(msg, zap.Error(err))
}
job.targetMgr.RemoveCollection(req.GetCollectionID())
job.targetObserver.ReleaseCollection(req.GetCollectionID())
waitCollectionReleased(job.dist, req.GetCollectionID())
metrics.QueryCoordNumCollections.WithLabelValues().Dec()
return nil
}
type LoadPartitionJob struct {
*BaseJob
req *querypb.LoadPartitionsRequest
dist *meta.DistributionManager
meta *meta.Meta
targetMgr *meta.TargetManager
broker meta.Broker
nodeMgr *session.NodeManager
}
func NewLoadPartitionJob(
ctx context.Context,
req *querypb.LoadPartitionsRequest,
dist *meta.DistributionManager,
meta *meta.Meta,
targetMgr *meta.TargetManager,
broker meta.Broker,
nodeMgr *session.NodeManager,
) *LoadPartitionJob {
return &LoadPartitionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
dist: dist,
meta: meta,
targetMgr: targetMgr,
broker: broker,
nodeMgr: nodeMgr,
}
}
func (job *LoadPartitionJob) PreExecute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
)
if req.GetReplicaNumber() <= 0 {
log.Info("request doesn't indicate the number of replicas, set it to 1",
zap.Int32("replicaNumber", req.GetReplicaNumber()))
req.ReplicaNumber = 1
}
if job.meta.Exist(req.GetCollectionID()) {
old := job.meta.GetCollection(req.GetCollectionID())
if old != nil {
msg := "load the partition after load collection is not supported"
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
} else if job.meta.GetReplicaNumber(req.GetCollectionID()) != req.GetReplicaNumber() {
msg := "collection with different replica number existed, release this collection first before changing its replica number"
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
} else if !typeutil.MapEqual(job.meta.GetFieldIndex(req.GetCollectionID()), req.GetFieldIndexID()) {
msg := fmt.Sprintf("collection with different index %v existed, release this collection first before changing its index",
job.meta.GetFieldIndex(req.GetCollectionID()))
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
}
// Check whether one of the given partitions not loaded
for _, partitionID := range req.GetPartitionIDs() {
partition := job.meta.GetPartition(partitionID)
if partition == nil {
msg := fmt.Sprintf("some partitions %v of collection %v has been loaded into QueryNode, please release partitions firstly",
req.GetPartitionIDs(),
req.GetCollectionID())
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
}
}
return ErrCollectionLoaded
}
return nil
}
func (job *LoadPartitionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.Int64s("partitionIDs", req.GetPartitionIDs()),
)
meta.GlobalFailedLoadCache.Remove(req.GetCollectionID())
// Clear stale replicas
err := job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
if err != nil {
log.Warn("failed to clear stale replicas", zap.Error(err))
return err
}
// Create replicas
replicas, err := utils.SpawnReplicasWithRG(job.meta,
req.GetCollectionID(),
req.GetResourceGroups(),
req.GetReplicaNumber(),
)
if err != nil {
msg := "failed to spawn replica for collection"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
for _, replica := range replicas {
log.Info("replica created",
zap.Int64("replicaID", replica.GetID()),
zap.Int64s("nodes", replica.GetNodes()),
zap.String("resourceGroup", replica.GetResourceGroup()))
}
// It's safe here to call UpdateCollectionNextTargetWithPartitions, as the collection not existing
err = job.targetMgr.UpdateCollectionNextTargetWithPartitions(req.GetCollectionID(), req.GetPartitionIDs()...)
if err != nil {
msg := "failed to update next targets for collection"
log.Error(msg,
zap.Int64s("partitionIDs", req.GetPartitionIDs()),
zap.Error(err))
return utils.WrapError(msg, err)
}
partitions := lo.Map(req.GetPartitionIDs(), func(partition int64, _ int) *meta.Partition {
return &meta.Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: req.GetCollectionID(),
PartitionID: partition,
ReplicaNumber: req.GetReplicaNumber(),
Status: querypb.LoadStatus_Loading,
FieldIndexID: req.GetFieldIndexID(),
},
CreatedAt: time.Now(),
}
})
err = job.meta.CollectionManager.PutPartition(partitions...)
if err != nil {
msg := "failed to store partitions"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
metrics.QueryCoordNumCollections.WithLabelValues().Inc()
return nil
}
func (job *LoadPartitionJob) PostExecute() {
if job.Error() != nil && !job.meta.Exist(job.CollectionID()) {
job.meta.ReplicaManager.RemoveCollection(job.CollectionID())
job.targetMgr.RemoveCollection(job.req.GetCollectionID())
}
}
type ReleasePartitionJob struct {
*BaseJob
req *querypb.ReleasePartitionsRequest
dist *meta.DistributionManager
meta *meta.Meta
targetMgr *meta.TargetManager
targetObserver *observers.TargetObserver
}
func NewReleasePartitionJob(ctx context.Context,
req *querypb.ReleasePartitionsRequest,
dist *meta.DistributionManager,
meta *meta.Meta,
targetMgr *meta.TargetManager,
targetObserver *observers.TargetObserver,
) *ReleasePartitionJob {
return &ReleasePartitionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
dist: dist,
meta: meta,
targetMgr: targetMgr,
targetObserver: targetObserver,
}
}
func (job *ReleasePartitionJob) PreExecute() error {
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", job.req.GetCollectionID()),
)
if job.meta.CollectionManager.GetLoadType(job.req.GetCollectionID()) == querypb.LoadType_LoadCollection {
msg := "releasing some partitions after load collection is not supported"
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
}
return nil
}
func (job *ReleasePartitionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
)
if !job.meta.CollectionManager.Exist(req.GetCollectionID()) {
log.Info("release collection end, the collection has not been loaded into QueryNode")
return nil
}
loadedPartitions := job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID())
partitionIDs := typeutil.NewUniqueSet(req.GetPartitionIDs()...)
toRelease := make([]int64, 0)
for _, partition := range loadedPartitions {
if partitionIDs.Contain(partition.GetPartitionID()) {
toRelease = append(toRelease, partition.GetPartitionID())
}
}
if len(toRelease) == len(loadedPartitions) { // All partitions are released, clear all
log.Info("release partitions covers all partitions, will remove the whole collection")
err := job.meta.CollectionManager.RemoveCollection(req.GetCollectionID())
if err != nil {
msg := "failed to release partitions from store"
log.Warn(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
if err != nil {
log.Warn("failed to remove replicas", zap.Error(err))
}
job.targetMgr.RemoveCollection(req.GetCollectionID())
job.targetObserver.ReleaseCollection(req.GetCollectionID())
waitCollectionReleased(job.dist, req.GetCollectionID())
} else {
err := job.meta.CollectionManager.RemovePartition(toRelease...)
if err != nil {
msg := "failed to release partitions from store"
log.Warn(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
job.targetMgr.RemovePartition(req.GetCollectionID(), toRelease...)
waitCollectionReleased(job.dist, req.GetCollectionID(), toRelease...)
}
metrics.QueryCoordNumCollections.WithLabelValues().Dec()
return nil
}

View File

@ -0,0 +1,397 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package job
import (
"context"
"fmt"
"time"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/observers"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
type LoadCollectionJob struct {
*BaseJob
req *querypb.LoadCollectionRequest
undo *UndoList
dist *meta.DistributionManager
meta *meta.Meta
cluster session.Cluster
targetMgr *meta.TargetManager
targetObserver *observers.TargetObserver
broker meta.Broker
nodeMgr *session.NodeManager
}
func NewLoadCollectionJob(
ctx context.Context,
req *querypb.LoadCollectionRequest,
dist *meta.DistributionManager,
meta *meta.Meta,
cluster session.Cluster,
targetMgr *meta.TargetManager,
targetObserver *observers.TargetObserver,
broker meta.Broker,
nodeMgr *session.NodeManager,
) *LoadCollectionJob {
return &LoadCollectionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
undo: NewUndoList(ctx, meta, cluster, targetMgr, targetObserver),
dist: dist,
meta: meta,
cluster: cluster,
targetMgr: targetMgr,
targetObserver: targetObserver,
broker: broker,
nodeMgr: nodeMgr,
}
}
func (job *LoadCollectionJob) PreExecute() error {
req := job.req
log := log.Ctx(job.ctx).With(zap.Int64("collectionID", req.GetCollectionID()))
if req.GetReplicaNumber() <= 0 {
log.Info("request doesn't indicate the number of replicas, set it to 1",
zap.Int32("replicaNumber", req.GetReplicaNumber()))
req.ReplicaNumber = 1
}
collection := job.meta.GetCollection(req.GetCollectionID())
if collection == nil {
return nil
}
if collection.GetReplicaNumber() != req.GetReplicaNumber() {
msg := fmt.Sprintf("collection with different replica number %d existed, release this collection first before changing its replica number",
job.meta.GetReplicaNumber(req.GetCollectionID()),
)
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
} else if !typeutil.MapEqual(collection.GetFieldIndexID(), req.GetFieldIndexID()) {
msg := fmt.Sprintf("collection with different index %v existed, release this collection first before changing its index",
collection.GetFieldIndexID())
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
}
return nil
}
func (job *LoadCollectionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(zap.Int64("collectionID", req.GetCollectionID()))
meta.GlobalFailedLoadCache.Remove(req.GetCollectionID())
// 1. Fetch target partitions
partitionIDs, err := job.broker.GetPartitions(job.ctx, req.GetCollectionID())
if err != nil {
msg := "failed to get partitions from RootCoord"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
loadedPartitionIDs := lo.Map(job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID()),
func(partition *meta.Partition, _ int) int64 {
return partition.GetPartitionID()
})
lackPartitionIDs := lo.FilterMap(partitionIDs, func(partID int64, _ int) (int64, bool) {
return partID, !lo.Contains(loadedPartitionIDs, partID)
})
if len(lackPartitionIDs) == 0 {
return ErrCollectionLoaded
}
job.undo.CollectionID = req.GetCollectionID()
job.undo.LackPartitions = lackPartitionIDs
log.Info("find partitions to load", zap.Int64s("partitions", lackPartitionIDs))
// 2. loadPartitions on QueryNodes
err = loadPartitions(job.ctx, job.meta, job.cluster, false, req.GetCollectionID(), lackPartitionIDs...)
if err != nil {
return err
}
job.undo.PartitionsLoaded = true
// 3. update next target
_, err = job.targetObserver.UpdateNextTarget(req.GetCollectionID(), partitionIDs...)
if err != nil {
msg := "failed to update next target"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
job.undo.TargetUpdated = true
colExisted := job.meta.CollectionManager.Exist(req.GetCollectionID())
if !colExisted {
// Clear stale replicas, https://github.com/milvus-io/milvus/issues/20444
err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
if err != nil {
msg := "failed to clear stale replicas"
log.Warn(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
}
// 4. create replica if not exist
replicas := job.meta.ReplicaManager.GetByCollection(req.GetCollectionID())
if len(replicas) == 0 {
replicas, err = utils.SpawnReplicasWithRG(job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber())
if err != nil {
msg := "failed to spawn replica for collection"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
for _, replica := range replicas {
log.Info("replica created", zap.Int64("replicaID", replica.GetID()),
zap.Int64s("nodes", replica.GetNodes()), zap.String("resourceGroup", replica.GetResourceGroup()))
}
job.undo.NewReplicaCreated = true
}
// 5. put collection/partitions meta
partitions := lo.Map(lackPartitionIDs, func(partID int64, _ int) *meta.Partition {
return &meta.Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: req.GetCollectionID(),
PartitionID: partID,
Status: querypb.LoadStatus_Loading,
},
CreatedAt: time.Now(),
}
})
collection := &meta.Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: req.GetCollectionID(),
ReplicaNumber: req.GetReplicaNumber(),
Status: querypb.LoadStatus_Loading,
FieldIndexID: req.GetFieldIndexID(),
LoadType: querypb.LoadType_LoadCollection,
},
CreatedAt: time.Now(),
}
err = job.meta.CollectionManager.PutCollection(collection, partitions...)
if err != nil {
msg := "failed to store collection and partitions"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
if !colExisted {
metrics.QueryCoordNumCollections.WithLabelValues().Inc()
}
metrics.QueryCoordNumPartitions.WithLabelValues().Add(float64(len(partitions)))
return nil
}
func (job *LoadCollectionJob) PostExecute() {
if job.Error() != nil && !errors.Is(job.Error(), ErrCollectionLoaded) {
job.undo.RollBack()
}
}
type LoadPartitionJob struct {
*BaseJob
req *querypb.LoadPartitionsRequest
undo *UndoList
dist *meta.DistributionManager
meta *meta.Meta
cluster session.Cluster
targetMgr *meta.TargetManager
targetObserver *observers.TargetObserver
broker meta.Broker
nodeMgr *session.NodeManager
}
func NewLoadPartitionJob(
ctx context.Context,
req *querypb.LoadPartitionsRequest,
dist *meta.DistributionManager,
meta *meta.Meta,
cluster session.Cluster,
targetMgr *meta.TargetManager,
targetObserver *observers.TargetObserver,
broker meta.Broker,
nodeMgr *session.NodeManager,
) *LoadPartitionJob {
return &LoadPartitionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
undo: NewUndoList(ctx, meta, cluster, targetMgr, targetObserver),
dist: dist,
meta: meta,
cluster: cluster,
targetMgr: targetMgr,
targetObserver: targetObserver,
broker: broker,
nodeMgr: nodeMgr,
}
}
func (job *LoadPartitionJob) PreExecute() error {
req := job.req
log := log.Ctx(job.ctx).With(zap.Int64("collectionID", req.GetCollectionID()))
if req.GetReplicaNumber() <= 0 {
log.Info("request doesn't indicate the number of replicas, set it to 1",
zap.Int32("replicaNumber", req.GetReplicaNumber()))
req.ReplicaNumber = 1
}
collection := job.meta.GetCollection(req.GetCollectionID())
if collection == nil {
return nil
}
if collection.GetReplicaNumber() != req.GetReplicaNumber() {
msg := "collection with different replica number existed, release this collection first before changing its replica number"
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
} else if !typeutil.MapEqual(collection.GetFieldIndexID(), req.GetFieldIndexID()) {
msg := fmt.Sprintf("collection with different index %v existed, release this collection first before changing its index",
job.meta.GetFieldIndex(req.GetCollectionID()))
log.Warn(msg)
return utils.WrapError(msg, ErrLoadParameterMismatched)
}
return nil
}
func (job *LoadPartitionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.Int64s("partitionIDs", req.GetPartitionIDs()),
)
meta.GlobalFailedLoadCache.Remove(req.GetCollectionID())
// 1. Fetch target partitions
loadedPartitionIDs := lo.Map(job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID()),
func(partition *meta.Partition, _ int) int64 {
return partition.GetPartitionID()
})
lackPartitionIDs := lo.FilterMap(req.GetPartitionIDs(), func(partID int64, _ int) (int64, bool) {
return partID, !lo.Contains(loadedPartitionIDs, partID)
})
if len(lackPartitionIDs) == 0 {
return ErrCollectionLoaded
}
job.undo.CollectionID = req.GetCollectionID()
job.undo.LackPartitions = lackPartitionIDs
log.Info("find partitions to load", zap.Int64s("partitions", lackPartitionIDs))
// 2. loadPartitions on QueryNodes
err := loadPartitions(job.ctx, job.meta, job.cluster, false, req.GetCollectionID(), lackPartitionIDs...)
if err != nil {
return err
}
job.undo.PartitionsLoaded = true
// 3. update next target
_, err = job.targetObserver.UpdateNextTarget(req.GetCollectionID(), append(loadedPartitionIDs, lackPartitionIDs...)...)
if err != nil {
msg := "failed to update next target"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
job.undo.TargetUpdated = true
if !job.meta.CollectionManager.Exist(req.GetCollectionID()) {
// Clear stale replicas, https://github.com/milvus-io/milvus/issues/20444
err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
if err != nil {
msg := "failed to clear stale replicas"
log.Warn(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
}
// 4. create replica if not exist
replicas := job.meta.ReplicaManager.GetByCollection(req.GetCollectionID())
if len(replicas) == 0 {
replicas, err = utils.SpawnReplicasWithRG(job.meta, req.GetCollectionID(), req.GetResourceGroups(), req.GetReplicaNumber())
if err != nil {
msg := "failed to spawn replica for collection"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
for _, replica := range replicas {
log.Info("replica created", zap.Int64("replicaID", replica.GetID()),
zap.Int64s("nodes", replica.GetNodes()), zap.String("resourceGroup", replica.GetResourceGroup()))
}
job.undo.NewReplicaCreated = true
}
// 5. put collection/partitions meta
partitions := lo.Map(lackPartitionIDs, func(partID int64, _ int) *meta.Partition {
return &meta.Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: req.GetCollectionID(),
PartitionID: partID,
Status: querypb.LoadStatus_Loading,
},
CreatedAt: time.Now(),
}
})
if !job.meta.CollectionManager.Exist(req.GetCollectionID()) {
collection := &meta.Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: req.GetCollectionID(),
ReplicaNumber: req.GetReplicaNumber(),
Status: querypb.LoadStatus_Loading,
FieldIndexID: req.GetFieldIndexID(),
LoadType: querypb.LoadType_LoadPartition,
},
CreatedAt: time.Now(),
}
err = job.meta.CollectionManager.PutCollection(collection, partitions...)
if err != nil {
msg := "failed to store collection and partitions"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
metrics.QueryCoordNumCollections.WithLabelValues().Inc()
} else { // collection exists, put partitions only
err = job.meta.CollectionManager.PutPartition(partitions...)
if err != nil {
msg := "failed to store partitions"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
}
metrics.QueryCoordNumPartitions.WithLabelValues().Add(float64(len(partitions)))
return nil
}
func (job *LoadPartitionJob) PostExecute() {
if job.Error() != nil && !errors.Is(job.Error(), ErrCollectionLoaded) {
job.undo.RollBack()
}
}

View File

@ -0,0 +1,176 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package job
import (
"context"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/observers"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
)
type ReleaseCollectionJob struct {
*BaseJob
req *querypb.ReleaseCollectionRequest
dist *meta.DistributionManager
meta *meta.Meta
targetMgr *meta.TargetManager
targetObserver *observers.TargetObserver
}
func NewReleaseCollectionJob(ctx context.Context,
req *querypb.ReleaseCollectionRequest,
dist *meta.DistributionManager,
meta *meta.Meta,
targetMgr *meta.TargetManager,
targetObserver *observers.TargetObserver,
) *ReleaseCollectionJob {
return &ReleaseCollectionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
dist: dist,
meta: meta,
targetMgr: targetMgr,
targetObserver: targetObserver,
}
}
func (job *ReleaseCollectionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(zap.Int64("collectionID", req.GetCollectionID()))
if !job.meta.CollectionManager.Exist(req.GetCollectionID()) {
log.Info("release collection end, the collection has not been loaded into QueryNode")
return nil
}
lenPartitions := len(job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID()))
err := job.meta.CollectionManager.RemoveCollection(req.GetCollectionID())
if err != nil {
msg := "failed to remove collection"
log.Warn(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
if err != nil {
msg := "failed to remove replicas"
log.Warn(msg, zap.Error(err))
}
job.targetMgr.RemoveCollection(req.GetCollectionID())
job.targetObserver.ReleaseCollection(req.GetCollectionID())
waitCollectionReleased(job.dist, req.GetCollectionID())
metrics.QueryCoordNumCollections.WithLabelValues().Dec()
metrics.QueryCoordNumPartitions.WithLabelValues().Sub(float64(lenPartitions))
return nil
}
type ReleasePartitionJob struct {
*BaseJob
releasePartitionsOnly bool
req *querypb.ReleasePartitionsRequest
dist *meta.DistributionManager
meta *meta.Meta
cluster session.Cluster
targetMgr *meta.TargetManager
targetObserver *observers.TargetObserver
}
func NewReleasePartitionJob(ctx context.Context,
req *querypb.ReleasePartitionsRequest,
dist *meta.DistributionManager,
meta *meta.Meta,
cluster session.Cluster,
targetMgr *meta.TargetManager,
targetObserver *observers.TargetObserver,
) *ReleasePartitionJob {
return &ReleasePartitionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
dist: dist,
meta: meta,
cluster: cluster,
targetMgr: targetMgr,
targetObserver: targetObserver,
}
}
func (job *ReleasePartitionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.Int64s("partitionIDs", req.GetPartitionIDs()),
)
if !job.meta.CollectionManager.Exist(req.GetCollectionID()) {
log.Info("release collection end, the collection has not been loaded into QueryNode")
return nil
}
loadedPartitions := job.meta.CollectionManager.GetPartitionsByCollection(req.GetCollectionID())
toRelease := lo.FilterMap(loadedPartitions, func(partition *meta.Partition, _ int) (int64, bool) {
return partition.GetPartitionID(), lo.Contains(req.GetPartitionIDs(), partition.GetPartitionID())
})
// If all partitions are released and LoadType is LoadPartition, clear all
if len(toRelease) == len(loadedPartitions) &&
job.meta.GetLoadType(req.GetCollectionID()) == querypb.LoadType_LoadPartition {
log.Info("release partitions covers all partitions, will remove the whole collection")
err := job.meta.CollectionManager.RemoveCollection(req.GetCollectionID())
if err != nil {
msg := "failed to release partitions from store"
log.Warn(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
err = job.meta.ReplicaManager.RemoveCollection(req.GetCollectionID())
if err != nil {
log.Warn("failed to remove replicas", zap.Error(err))
}
job.targetMgr.RemoveCollection(req.GetCollectionID())
job.targetObserver.ReleaseCollection(req.GetCollectionID())
metrics.QueryCoordNumCollections.WithLabelValues().Dec()
waitCollectionReleased(job.dist, req.GetCollectionID())
} else {
err := releasePartitions(job.ctx, job.meta, job.cluster, false, req.GetCollectionID(), toRelease...)
if err != nil {
loadPartitions(job.ctx, job.meta, job.cluster, true, req.GetCollectionID(), toRelease...)
return err
}
err = job.meta.CollectionManager.RemovePartition(toRelease...)
if err != nil {
loadPartitions(job.ctx, job.meta, job.cluster, true, req.GetCollectionID(), toRelease...)
msg := "failed to release partitions from store"
log.Warn(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
job.targetMgr.RemovePartition(req.GetCollectionID(), toRelease...)
waitCollectionReleased(job.dist, req.GetCollectionID(), toRelease...)
}
metrics.QueryCoordNumPartitions.WithLabelValues().Sub(float64(len(toRelease)))
return nil
}

View File

@ -0,0 +1,103 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package job
import (
"context"
"time"
"github.com/cockroachdb/errors"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
)
type SyncNewCreatedPartitionJob struct {
*BaseJob
req *querypb.SyncNewCreatedPartitionRequest
meta *meta.Meta
cluster session.Cluster
}
func NewSyncNewCreatedPartitionJob(
ctx context.Context,
req *querypb.SyncNewCreatedPartitionRequest,
meta *meta.Meta,
cluster session.Cluster,
) *SyncNewCreatedPartitionJob {
return &SyncNewCreatedPartitionJob{
BaseJob: NewBaseJob(ctx, req.Base.GetMsgID(), req.GetCollectionID()),
req: req,
meta: meta,
cluster: cluster,
}
}
func (job *SyncNewCreatedPartitionJob) PreExecute() error {
// check if collection not load or loadType is loadPartition
collection := job.meta.GetCollection(job.req.GetCollectionID())
if collection == nil || collection.GetLoadType() == querypb.LoadType_LoadPartition {
return ErrPartitionNotInTarget
}
// check if partition already existed
if partition := job.meta.GetPartition(job.req.GetPartitionID()); partition != nil {
return ErrPartitionNotInTarget
}
return nil
}
func (job *SyncNewCreatedPartitionJob) Execute() error {
req := job.req
log := log.Ctx(job.ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.Int64("partitionID", req.GetPartitionID()),
)
err := loadPartitions(job.ctx, job.meta, job.cluster, false, req.GetCollectionID(), req.GetPartitionID())
if err != nil {
return err
}
partition := &meta.Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: req.GetCollectionID(),
PartitionID: req.GetPartitionID(),
Status: querypb.LoadStatus_Loaded,
},
LoadPercentage: 100,
CreatedAt: time.Now(),
}
err = job.meta.CollectionManager.PutPartition(partition)
if err != nil {
msg := "failed to store partitions"
log.Error(msg, zap.Error(err))
return utils.WrapError(msg, err)
}
return nil
}
func (job *SyncNewCreatedPartitionJob) PostExecute() {
if job.Error() != nil && !errors.Is(job.Error(), ErrPartitionNotInTarget) {
releasePartitions(job.ctx, job.meta, job.cluster, true, job.req.GetCollectionID(), job.req.GetPartitionID())
}
}

View File

@ -21,10 +21,11 @@ import (
"testing" "testing"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/samber/lo"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/kv" "github.com/milvus-io/milvus/internal/kv"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
@ -33,6 +34,7 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/observers" "github.com/milvus-io/milvus/internal/querycoordv2/observers"
. "github.com/milvus-io/milvus/internal/querycoordv2/params" . "github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/etcd"
) )
@ -56,6 +58,7 @@ type JobSuite struct {
store meta.Store store meta.Store
dist *meta.DistributionManager dist *meta.DistributionManager
meta *meta.Meta meta *meta.Meta
cluster *session.MockCluster
targetMgr *meta.TargetManager targetMgr *meta.TargetManager
targetObserver *observers.TargetObserver targetObserver *observers.TargetObserver
broker *meta.MockBroker broker *meta.MockBroker
@ -70,8 +73,8 @@ func (suite *JobSuite) SetupSuite() {
suite.collections = []int64{1000, 1001} suite.collections = []int64{1000, 1001}
suite.partitions = map[int64][]int64{ suite.partitions = map[int64][]int64{
1000: {100, 101}, 1000: {100, 101, 102},
1001: {102, 103}, 1001: {103, 104, 105},
} }
suite.channels = map[int64][]string{ suite.channels = map[int64][]string{
1000: {"1000-dmc0", "1000-dmc1"}, 1000: {"1000-dmc0", "1000-dmc1"},
@ -81,10 +84,12 @@ func (suite *JobSuite) SetupSuite() {
1000: { 1000: {
100: {1, 2}, 100: {1, 2},
101: {3, 4}, 101: {3, 4},
102: {5, 6},
}, },
1001: { 1001: {
102: {5, 6},
103: {7, 8}, 103: {7, 8},
104: {9, 10},
105: {11, 12},
}, },
} }
suite.loadTypes = map[int64]querypb.LoadType{ suite.loadTypes = map[int64]querypb.LoadType{
@ -115,6 +120,14 @@ func (suite *JobSuite) SetupSuite() {
Return(vChannels, segmentBinlogs, nil) Return(vChannels, segmentBinlogs, nil)
} }
} }
suite.cluster = session.NewMockCluster(suite.T())
suite.cluster.EXPECT().
LoadPartitions(mock.Anything, mock.Anything, mock.Anything).
Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil)
suite.cluster.EXPECT().
ReleasePartitions(mock.Anything, mock.Anything, mock.Anything).
Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil)
} }
func (suite *JobSuite) SetupTest() { func (suite *JobSuite) SetupTest() {
@ -140,6 +153,7 @@ func (suite *JobSuite) SetupTest() {
suite.dist, suite.dist,
suite.broker, suite.broker,
) )
suite.targetObserver.Start(context.Background())
suite.scheduler = NewScheduler() suite.scheduler = NewScheduler()
suite.scheduler.Start(context.Background()) suite.scheduler.Start(context.Background())
@ -160,19 +174,14 @@ func (suite *JobSuite) SetupTest() {
func (suite *JobSuite) TearDownTest() { func (suite *JobSuite) TearDownTest() {
suite.kv.Close() suite.kv.Close()
suite.scheduler.Stop() suite.scheduler.Stop()
suite.targetObserver.Stop()
} }
func (suite *JobSuite) BeforeTest(suiteName, testName string) { func (suite *JobSuite) BeforeTest(suiteName, testName string) {
switch testName { for collection, partitions := range suite.partitions {
case "TestLoadCollection": suite.broker.EXPECT().
for collection, partitions := range suite.partitions { GetPartitions(mock.Anything, collection).
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection { Return(partitions, nil)
continue
}
suite.broker.EXPECT().
GetPartitions(mock.Anything, collection).
Return(partitions, nil)
}
} }
} }
@ -195,7 +204,9 @@ func (suite *JobSuite) TestLoadCollection() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -204,7 +215,7 @@ func (suite *JobSuite) TestLoadCollection() {
suite.NoError(err) suite.NoError(err)
suite.EqualValues(1, suite.meta.GetReplicaNumber(collection)) suite.EqualValues(1, suite.meta.GetReplicaNumber(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection) suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.assertLoaded(collection) suite.assertCollectionLoaded(collection)
} }
// Test load again // Test load again
@ -220,7 +231,9 @@ func (suite *JobSuite) TestLoadCollection() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -243,7 +256,9 @@ func (suite *JobSuite) TestLoadCollection() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -268,13 +283,15 @@ func (suite *JobSuite) TestLoadCollection() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
suite.scheduler.Add(job) suite.scheduler.Add(job)
err := job.Wait() err := job.Wait()
suite.ErrorIs(err, ErrLoadParameterMismatched) suite.ErrorIs(err, ErrCollectionLoaded)
} }
suite.meta.ResourceManager.AddResourceGroup("rg1") suite.meta.ResourceManager.AddResourceGroup("rg1")
@ -292,7 +309,9 @@ func (suite *JobSuite) TestLoadCollection() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -302,7 +321,7 @@ func (suite *JobSuite) TestLoadCollection() {
// Load with 3 replica on 3 rg // Load with 3 replica on 3 rg
req = &querypb.LoadCollectionRequest{ req = &querypb.LoadCollectionRequest{
CollectionID: 1002, CollectionID: 1001,
ReplicaNumber: 3, ReplicaNumber: 3,
ResourceGroups: []string{"rg1", "rg2", "rg3"}, ResourceGroups: []string{"rg1", "rg2", "rg3"},
} }
@ -311,7 +330,9 @@ func (suite *JobSuite) TestLoadCollection() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -338,7 +359,9 @@ func (suite *JobSuite) TestLoadCollectionWithReplicas() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -368,7 +391,9 @@ func (suite *JobSuite) TestLoadCollectionWithDiffIndex() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -377,7 +402,7 @@ func (suite *JobSuite) TestLoadCollectionWithDiffIndex() {
suite.NoError(err) suite.NoError(err)
suite.EqualValues(1, suite.meta.GetReplicaNumber(collection)) suite.EqualValues(1, suite.meta.GetReplicaNumber(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection, suite.partitions[collection]...) suite.targetMgr.UpdateCollectionCurrentTarget(collection, suite.partitions[collection]...)
suite.assertLoaded(collection) suite.assertCollectionLoaded(collection)
} }
// Test load with different index // Test load with different index
@ -396,7 +421,9 @@ func (suite *JobSuite) TestLoadCollectionWithDiffIndex() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -425,7 +452,9 @@ func (suite *JobSuite) TestLoadPartition() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -434,7 +463,7 @@ func (suite *JobSuite) TestLoadPartition() {
suite.NoError(err) suite.NoError(err)
suite.EqualValues(1, suite.meta.GetReplicaNumber(collection)) suite.EqualValues(1, suite.meta.GetReplicaNumber(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection, suite.partitions[collection]...) suite.targetMgr.UpdateCollectionCurrentTarget(collection, suite.partitions[collection]...)
suite.assertLoaded(collection) suite.assertCollectionLoaded(collection)
} }
// Test load partition again // Test load partition again
@ -453,7 +482,9 @@ func (suite *JobSuite) TestLoadPartition() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -478,7 +509,9 @@ func (suite *JobSuite) TestLoadPartition() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -496,20 +529,22 @@ func (suite *JobSuite) TestLoadPartition() {
req := &querypb.LoadPartitionsRequest{ req := &querypb.LoadPartitionsRequest{
CollectionID: collection, CollectionID: collection,
PartitionIDs: append(suite.partitions[collection], 200), PartitionIDs: append(suite.partitions[collection], 200),
ReplicaNumber: 3, ReplicaNumber: 1,
} }
job := NewLoadPartitionJob( job := NewLoadPartitionJob(
ctx, ctx,
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
suite.scheduler.Add(job) suite.scheduler.Add(job)
err := job.Wait() err := job.Wait()
suite.ErrorIs(err, ErrLoadParameterMismatched) suite.NoError(err)
} }
// Test load collection while partitions exists // Test load collection while partitions exists
@ -527,13 +562,15 @@ func (suite *JobSuite) TestLoadPartition() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
suite.scheduler.Add(job) suite.scheduler.Add(job)
err := job.Wait() err := job.Wait()
suite.ErrorIs(err, ErrLoadParameterMismatched) suite.ErrorIs(err, ErrCollectionLoaded)
} }
suite.meta.ResourceManager.AddResourceGroup("rg1") suite.meta.ResourceManager.AddResourceGroup("rg1")
@ -541,9 +578,11 @@ func (suite *JobSuite) TestLoadPartition() {
suite.meta.ResourceManager.AddResourceGroup("rg3") suite.meta.ResourceManager.AddResourceGroup("rg3")
// test load 3 replica in 1 rg, should pass rg check // test load 3 replica in 1 rg, should pass rg check
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(999)).Return([]int64{888}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(999), int64(888)).Return(nil, nil, nil)
req := &querypb.LoadPartitionsRequest{ req := &querypb.LoadPartitionsRequest{
CollectionID: 100, CollectionID: 999,
PartitionIDs: []int64{1001}, PartitionIDs: []int64{888},
ReplicaNumber: 3, ReplicaNumber: 3,
ResourceGroups: []string{"rg1"}, ResourceGroups: []string{"rg1"},
} }
@ -552,7 +591,9 @@ func (suite *JobSuite) TestLoadPartition() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -561,9 +602,11 @@ func (suite *JobSuite) TestLoadPartition() {
suite.Contains(err.Error(), meta.ErrNodeNotEnough.Error()) suite.Contains(err.Error(), meta.ErrNodeNotEnough.Error())
// test load 3 replica in 3 rg, should pass rg check // test load 3 replica in 3 rg, should pass rg check
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(999)).Return([]int64{888}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(999), int64(888)).Return(nil, nil, nil)
req = &querypb.LoadPartitionsRequest{ req = &querypb.LoadPartitionsRequest{
CollectionID: 102, CollectionID: 999,
PartitionIDs: []int64{1001}, PartitionIDs: []int64{888},
ReplicaNumber: 3, ReplicaNumber: 3,
ResourceGroups: []string{"rg1", "rg2", "rg3"}, ResourceGroups: []string{"rg1", "rg2", "rg3"},
} }
@ -572,7 +615,9 @@ func (suite *JobSuite) TestLoadPartition() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -581,6 +626,120 @@ func (suite *JobSuite) TestLoadPartition() {
suite.Contains(err.Error(), meta.ErrNodeNotEnough.Error()) suite.Contains(err.Error(), meta.ErrNodeNotEnough.Error())
} }
func (suite *JobSuite) TestDynamicLoad() {
ctx := context.Background()
collection := suite.collections[0]
p0, p1, p2 := suite.partitions[collection][0], suite.partitions[collection][1], suite.partitions[collection][2]
newLoadPartJob := func(partitions ...int64) *LoadPartitionJob {
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: partitions,
ReplicaNumber: 1,
}
job := NewLoadPartitionJob(
ctx,
req,
suite.dist,
suite.meta,
suite.cluster,
suite.targetMgr,
suite.targetObserver,
suite.broker,
suite.nodeMgr,
)
return job
}
newLoadColJob := func() *LoadCollectionJob {
req := &querypb.LoadCollectionRequest{
CollectionID: collection,
ReplicaNumber: 1,
}
job := NewLoadCollectionJob(
ctx,
req,
suite.dist,
suite.meta,
suite.cluster,
suite.targetMgr,
suite.targetObserver,
suite.broker,
suite.nodeMgr,
)
return job
}
// loaded: none
// action: load p0, p1, p2
// expect: p0, p1, p2 loaded
job := newLoadPartJob(p0, p1, p2)
suite.scheduler.Add(job)
err := job.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.assertPartitionLoaded(collection, p0, p1, p2)
// loaded: p0, p1, p2
// action: load p0, p1, p2
// expect: do nothing, p0, p1, p2 loaded
job = newLoadPartJob(p0, p1, p2)
suite.scheduler.Add(job)
err = job.Wait()
suite.ErrorIs(err, ErrCollectionLoaded)
suite.assertPartitionLoaded(collection)
// loaded: p0, p1
// action: load p2
// expect: p0, p1, p2 loaded
suite.releaseAll()
job = newLoadPartJob(p0, p1)
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.assertPartitionLoaded(collection, p0, p1)
job = newLoadPartJob(p2)
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.assertPartitionLoaded(collection, p2)
// loaded: p0, p1
// action: load p1, p2
// expect: p0, p1, p2 loaded
suite.releaseAll()
job = newLoadPartJob(p0, p1)
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.assertPartitionLoaded(collection, p0, p1)
job = newLoadPartJob(p1, p2)
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.assertPartitionLoaded(collection, p2)
// loaded: p0, p1
// action: load col
// expect: col loaded
suite.releaseAll()
job = newLoadPartJob(p0, p1)
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.assertPartitionLoaded(collection, p0, p1)
colJob := newLoadColJob()
suite.scheduler.Add(colJob)
err = colJob.Wait()
suite.NoError(err)
suite.targetMgr.UpdateCollectionCurrentTarget(collection)
suite.assertPartitionLoaded(collection, p2)
}
func (suite *JobSuite) TestLoadPartitionWithReplicas() { func (suite *JobSuite) TestLoadPartitionWithReplicas() {
ctx := context.Background() ctx := context.Background()
@ -600,7 +759,9 @@ func (suite *JobSuite) TestLoadPartitionWithReplicas() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -631,7 +792,9 @@ func (suite *JobSuite) TestLoadPartitionWithDiffIndex() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -640,7 +803,7 @@ func (suite *JobSuite) TestLoadPartitionWithDiffIndex() {
suite.NoError(err) suite.NoError(err)
suite.EqualValues(1, suite.meta.GetReplicaNumber(collection)) suite.EqualValues(1, suite.meta.GetReplicaNumber(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection, suite.partitions[collection]...) suite.targetMgr.UpdateCollectionCurrentTarget(collection, suite.partitions[collection]...)
suite.assertLoaded(collection) suite.assertCollectionLoaded(collection)
} }
// Test load partition with different index // Test load partition with different index
@ -661,7 +824,9 @@ func (suite *JobSuite) TestLoadPartitionWithDiffIndex() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -692,7 +857,7 @@ func (suite *JobSuite) TestReleaseCollection() {
suite.scheduler.Add(job) suite.scheduler.Add(job)
err := job.Wait() err := job.Wait()
suite.NoError(err) suite.NoError(err)
suite.assertReleased(collection) suite.assertCollectionReleased(collection)
} }
// Test release again // Test release again
@ -711,7 +876,7 @@ func (suite *JobSuite) TestReleaseCollection() {
suite.scheduler.Add(job) suite.scheduler.Add(job)
err := job.Wait() err := job.Wait()
suite.NoError(err) suite.NoError(err)
suite.assertReleased(collection) suite.assertCollectionReleased(collection)
} }
} }
@ -731,18 +896,14 @@ func (suite *JobSuite) TestReleasePartition() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver, suite.targetObserver,
) )
suite.scheduler.Add(job) suite.scheduler.Add(job)
err := job.Wait() err := job.Wait()
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection { suite.NoError(err)
suite.ErrorIs(err, ErrLoadParameterMismatched) suite.assertPartitionReleased(collection, suite.partitions[collection]...)
suite.assertLoaded(collection)
} else {
suite.NoError(err)
suite.assertReleased(collection)
}
} }
// Test release again // Test release again
@ -756,18 +917,14 @@ func (suite *JobSuite) TestReleasePartition() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver, suite.targetObserver,
) )
suite.scheduler.Add(job) suite.scheduler.Add(job)
err := job.Wait() err := job.Wait()
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection { suite.NoError(err)
suite.ErrorIs(err, ErrLoadParameterMismatched) suite.assertPartitionReleased(collection, suite.partitions[collection]...)
suite.assertLoaded(collection)
} else {
suite.NoError(err)
suite.assertReleased(collection)
}
} }
// Test release partial partitions // Test release partial partitions
@ -783,24 +940,114 @@ func (suite *JobSuite) TestReleasePartition() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver, suite.targetObserver,
) )
suite.scheduler.Add(job) suite.scheduler.Add(job)
err := job.Wait() err := job.Wait()
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection { suite.NoError(err)
suite.ErrorIs(err, ErrLoadParameterMismatched) suite.True(suite.meta.Exist(collection))
suite.assertLoaded(collection) partitions := suite.meta.GetPartitionsByCollection(collection)
} else { suite.Len(partitions, 1)
suite.NoError(err) suite.Equal(suite.partitions[collection][0], partitions[0].GetPartitionID())
suite.True(suite.meta.Exist(collection)) suite.assertPartitionReleased(collection, suite.partitions[collection][1:]...)
partitions := suite.meta.GetPartitionsByCollection(collection)
suite.Len(partitions, 1)
suite.Equal(suite.partitions[collection][0], partitions[0].GetPartitionID())
}
} }
} }
func (suite *JobSuite) TestDynamicRelease() {
ctx := context.Background()
col0, col1 := suite.collections[0], suite.collections[1]
p0, p1, p2 := suite.partitions[col0][0], suite.partitions[col0][1], suite.partitions[col0][2]
p3, p4, p5 := suite.partitions[col1][0], suite.partitions[col1][1], suite.partitions[col1][2]
newReleasePartJob := func(col int64, partitions ...int64) *ReleasePartitionJob {
req := &querypb.ReleasePartitionsRequest{
CollectionID: col,
PartitionIDs: partitions,
}
job := NewReleasePartitionJob(
ctx,
req,
suite.dist,
suite.meta,
suite.cluster,
suite.targetMgr,
suite.targetObserver,
)
return job
}
newReleaseColJob := func(col int64) *ReleaseCollectionJob {
req := &querypb.ReleaseCollectionRequest{
CollectionID: col,
}
job := NewReleaseCollectionJob(
ctx,
req,
suite.dist,
suite.meta,
suite.targetMgr,
suite.targetObserver,
)
return job
}
// loaded: p0, p1, p2
// action: release p0
// expect: p0 released, p1, p2 loaded
suite.loadAll()
job := newReleasePartJob(col0, p0)
suite.scheduler.Add(job)
err := job.Wait()
suite.NoError(err)
suite.assertPartitionReleased(col0, p0)
suite.assertPartitionLoaded(col0, p1, p2)
// loaded: p1, p2
// action: release p0, p1
// expect: p1 released, p2 loaded
job = newReleasePartJob(col0, p0, p1)
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.assertPartitionReleased(col0, p0, p1)
suite.assertPartitionLoaded(col0, p2)
// loaded: p2
// action: release p2
// expect: loadType=col: col loaded, p2 released
job = newReleasePartJob(col0, p2)
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.assertPartitionReleased(col0, p0, p1, p2)
suite.True(suite.meta.Exist(col0))
// loaded: p0, p1, p2
// action: release col
// expect: col released
suite.releaseAll()
suite.loadAll()
releaseColJob := newReleaseColJob(col0)
suite.scheduler.Add(releaseColJob)
err = releaseColJob.Wait()
suite.NoError(err)
suite.assertCollectionReleased(col0)
suite.assertPartitionReleased(col0, p0, p1, p2)
// loaded: p3, p4, p5
// action: release p3, p4, p5
// expect: loadType=partition: col released
suite.releaseAll()
suite.loadAll()
job = newReleasePartJob(col1, p3, p4, p5)
suite.scheduler.Add(job)
err = job.Wait()
suite.NoError(err)
suite.assertCollectionReleased(col1)
suite.assertPartitionReleased(col1, p3, p4, p5)
}
func (suite *JobSuite) TestLoadCollectionStoreFailed() { func (suite *JobSuite) TestLoadCollectionStoreFailed() {
// Store collection failed // Store collection failed
store := meta.NewMockStore(suite.T()) store := meta.NewMockStore(suite.T())
@ -818,14 +1065,10 @@ func (suite *JobSuite) TestLoadCollectionStoreFailed() {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection { if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
continue continue
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil)
err := errors.New("failed to store collection") err := errors.New("failed to store collection")
store.EXPECT().SaveReplica(mock.Anything).Return(nil) store.EXPECT().SaveReplica(mock.Anything).Return(nil)
store.EXPECT().SaveCollection(&querypb.CollectionLoadInfo{ store.EXPECT().SaveCollection(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(err)
CollectionID: collection,
ReplicaNumber: 1,
Status: querypb.LoadStatus_Loading,
}).Return(err)
store.EXPECT().ReleaseReplicas(collection).Return(nil) store.EXPECT().ReleaseReplicas(collection).Return(nil)
req := &querypb.LoadCollectionRequest{ req := &querypb.LoadCollectionRequest{
@ -836,7 +1079,9 @@ func (suite *JobSuite) TestLoadCollectionStoreFailed() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -866,7 +1111,7 @@ func (suite *JobSuite) TestLoadPartitionStoreFailed() {
} }
store.EXPECT().SaveReplica(mock.Anything).Return(nil) store.EXPECT().SaveReplica(mock.Anything).Return(nil)
store.EXPECT().SavePartition(mock.Anything, mock.Anything).Return(err) store.EXPECT().SaveCollection(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(err)
store.EXPECT().ReleaseReplicas(collection).Return(nil) store.EXPECT().ReleaseReplicas(collection).Return(nil)
req := &querypb.LoadPartitionsRequest{ req := &querypb.LoadPartitionsRequest{
@ -878,7 +1123,9 @@ func (suite *JobSuite) TestLoadPartitionStoreFailed() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -892,6 +1139,9 @@ func (suite *JobSuite) TestLoadCreateReplicaFailed() {
// Store replica failed // Store replica failed
suite.meta = meta.NewMeta(ErrorIDAllocator(), suite.store, session.NewNodeManager()) suite.meta = meta.NewMeta(ErrorIDAllocator(), suite.store, session.NewNodeManager())
for _, collection := range suite.collections { for _, collection := range suite.collections {
suite.broker.EXPECT().
GetPartitions(mock.Anything, collection).
Return(suite.partitions[collection], nil)
req := &querypb.LoadCollectionRequest{ req := &querypb.LoadCollectionRequest{
CollectionID: collection, CollectionID: collection,
} }
@ -900,7 +1150,9 @@ func (suite *JobSuite) TestLoadCreateReplicaFailed() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -910,6 +1162,59 @@ func (suite *JobSuite) TestLoadCreateReplicaFailed() {
} }
} }
func (suite *JobSuite) TestSyncNewCreatedPartition() {
newPartition := int64(999)
// test sync new created partition
suite.loadAll()
req := &querypb.SyncNewCreatedPartitionRequest{
CollectionID: suite.collections[0],
PartitionID: newPartition,
}
job := NewSyncNewCreatedPartitionJob(
context.Background(),
req,
suite.meta,
suite.cluster,
)
suite.scheduler.Add(job)
err := job.Wait()
suite.NoError(err)
partition := suite.meta.CollectionManager.GetPartition(newPartition)
suite.NotNil(partition)
suite.Equal(querypb.LoadStatus_Loaded, partition.GetStatus())
// test collection not loaded
req = &querypb.SyncNewCreatedPartitionRequest{
CollectionID: int64(888),
PartitionID: newPartition,
}
job = NewSyncNewCreatedPartitionJob(
context.Background(),
req,
suite.meta,
suite.cluster,
)
suite.scheduler.Add(job)
err = job.Wait()
suite.ErrorIs(err, ErrPartitionNotInTarget)
// test collection loaded, but its loadType is loadPartition
req = &querypb.SyncNewCreatedPartitionRequest{
CollectionID: suite.collections[1],
PartitionID: newPartition,
}
job = NewSyncNewCreatedPartitionJob(
context.Background(),
req,
suite.meta,
suite.cluster,
)
suite.scheduler.Add(job)
err = job.Wait()
suite.ErrorIs(err, ErrPartitionNotInTarget)
}
func (suite *JobSuite) loadAll() { func (suite *JobSuite) loadAll() {
ctx := context.Background() ctx := context.Background()
for _, collection := range suite.collections { for _, collection := range suite.collections {
@ -922,7 +1227,9 @@ func (suite *JobSuite) loadAll() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -932,6 +1239,7 @@ func (suite *JobSuite) loadAll() {
suite.EqualValues(1, suite.meta.GetReplicaNumber(collection)) suite.EqualValues(1, suite.meta.GetReplicaNumber(collection))
suite.True(suite.meta.Exist(collection)) suite.True(suite.meta.Exist(collection))
suite.NotNil(suite.meta.GetCollection(collection)) suite.NotNil(suite.meta.GetCollection(collection))
suite.NotNil(suite.meta.GetPartitionsByCollection(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection) suite.targetMgr.UpdateCollectionCurrentTarget(collection)
} else { } else {
req := &querypb.LoadPartitionsRequest{ req := &querypb.LoadPartitionsRequest{
@ -943,7 +1251,9 @@ func (suite *JobSuite) loadAll() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -952,6 +1262,7 @@ func (suite *JobSuite) loadAll() {
suite.NoError(err) suite.NoError(err)
suite.EqualValues(1, suite.meta.GetReplicaNumber(collection)) suite.EqualValues(1, suite.meta.GetReplicaNumber(collection))
suite.True(suite.meta.Exist(collection)) suite.True(suite.meta.Exist(collection))
suite.NotNil(suite.meta.GetCollection(collection))
suite.NotNil(suite.meta.GetPartitionsByCollection(collection)) suite.NotNil(suite.meta.GetPartitionsByCollection(collection))
suite.targetMgr.UpdateCollectionCurrentTarget(collection) suite.targetMgr.UpdateCollectionCurrentTarget(collection)
} }
@ -975,24 +1286,43 @@ func (suite *JobSuite) releaseAll() {
suite.scheduler.Add(job) suite.scheduler.Add(job)
err := job.Wait() err := job.Wait()
suite.NoError(err) suite.NoError(err)
suite.assertReleased(collection) suite.assertCollectionReleased(collection)
} }
} }
func (suite *JobSuite) assertLoaded(collection int64) { func (suite *JobSuite) assertCollectionLoaded(collection int64) {
suite.True(suite.meta.Exist(collection)) suite.True(suite.meta.Exist(collection))
suite.NotEqual(0, len(suite.meta.ReplicaManager.GetByCollection(collection)))
for _, channel := range suite.channels[collection] { for _, channel := range suite.channels[collection] {
suite.NotNil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget)) suite.NotNil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget))
} }
for _, partitions := range suite.segments[collection] { for _, segments := range suite.segments[collection] {
for _, segment := range partitions { for _, segment := range segments {
suite.NotNil(suite.targetMgr.GetHistoricalSegment(collection, segment, meta.CurrentTarget)) suite.NotNil(suite.targetMgr.GetHistoricalSegment(collection, segment, meta.CurrentTarget))
} }
} }
} }
func (suite *JobSuite) assertReleased(collection int64) { func (suite *JobSuite) assertPartitionLoaded(collection int64, partitionIDs ...int64) {
suite.True(suite.meta.Exist(collection))
suite.NotEqual(0, len(suite.meta.ReplicaManager.GetByCollection(collection)))
for _, channel := range suite.channels[collection] {
suite.NotNil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget))
}
for partitionID, segments := range suite.segments[collection] {
if !lo.Contains(partitionIDs, partitionID) {
continue
}
suite.NotNil(suite.meta.GetPartition(partitionID))
for _, segment := range segments {
suite.NotNil(suite.targetMgr.GetHistoricalSegment(collection, segment, meta.CurrentTarget))
}
}
}
func (suite *JobSuite) assertCollectionReleased(collection int64) {
suite.False(suite.meta.Exist(collection)) suite.False(suite.meta.Exist(collection))
suite.Equal(0, len(suite.meta.ReplicaManager.GetByCollection(collection)))
for _, channel := range suite.channels[collection] { for _, channel := range suite.channels[collection] {
suite.Nil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget)) suite.Nil(suite.targetMgr.GetDmChannel(collection, channel, meta.CurrentTarget))
} }
@ -1003,6 +1333,16 @@ func (suite *JobSuite) assertReleased(collection int64) {
} }
} }
func (suite *JobSuite) assertPartitionReleased(collection int64, partitionIDs ...int64) {
for _, partition := range partitionIDs {
suite.Nil(suite.meta.GetPartition(partition))
segments := suite.segments[collection][partition]
for _, segment := range segments {
suite.Nil(suite.targetMgr.GetHistoricalSegment(collection, segment, meta.CurrentTarget))
}
}
}
func TestJob(t *testing.T) { func TestJob(t *testing.T) {
suite.Run(t, new(JobSuite)) suite.Run(t, new(JobSuite))
} }

View File

@ -0,0 +1,68 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package job
import (
"context"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/observers"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
)
type UndoList struct {
PartitionsLoaded bool // indicates if partitions loaded in QueryNodes during loading
TargetUpdated bool // indicates if target updated during loading
NewReplicaCreated bool // indicates if created new replicas during loading
CollectionID int64
LackPartitions []int64
ctx context.Context
meta *meta.Meta
cluster session.Cluster
targetMgr *meta.TargetManager
targetObserver *observers.TargetObserver
}
func NewUndoList(ctx context.Context, meta *meta.Meta,
cluster session.Cluster, targetMgr *meta.TargetManager, targetObserver *observers.TargetObserver) *UndoList {
return &UndoList{
ctx: ctx,
meta: meta,
cluster: cluster,
targetMgr: targetMgr,
targetObserver: targetObserver,
}
}
func (u *UndoList) RollBack() {
if u.PartitionsLoaded {
releasePartitions(u.ctx, u.meta, u.cluster, true, u.CollectionID, u.LackPartitions...)
}
if u.TargetUpdated {
if !u.meta.CollectionManager.Exist(u.CollectionID) {
u.targetMgr.RemoveCollection(u.CollectionID)
u.targetObserver.ReleaseCollection(u.CollectionID)
} else {
u.targetMgr.RemovePartition(u.CollectionID, u.LackPartitions...)
}
}
if u.NewReplicaCreated {
u.meta.ReplicaManager.RemoveCollection(u.CollectionID)
}
}

View File

@ -17,11 +17,17 @@
package job package job
import ( import (
"context"
"fmt"
"time" "time"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/samber/lo" "github.com/samber/lo"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querycoordv2/meta"
"github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/util/typeutil"
) )
// waitCollectionReleased blocks until // waitCollectionReleased blocks until
@ -49,3 +55,57 @@ func waitCollectionReleased(dist *meta.DistributionManager, collection int64, pa
time.Sleep(200 * time.Millisecond) time.Sleep(200 * time.Millisecond)
} }
} }
func loadPartitions(ctx context.Context, meta *meta.Meta, cluster session.Cluster,
ignoreErr bool, collection int64, partitions ...int64) error {
replicas := meta.ReplicaManager.GetByCollection(collection)
loadReq := &querypb.LoadPartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadPartitions,
},
CollectionID: collection,
PartitionIDs: partitions,
}
for _, replica := range replicas {
for _, node := range replica.GetNodes() {
status, err := cluster.LoadPartitions(ctx, node, loadReq)
if ignoreErr {
continue
}
if err != nil {
return err
}
if status.GetErrorCode() != commonpb.ErrorCode_Success {
return fmt.Errorf("QueryNode failed to loadPartition, nodeID=%d, err=%s", node, status.GetReason())
}
}
}
return nil
}
func releasePartitions(ctx context.Context, meta *meta.Meta, cluster session.Cluster,
ignoreErr bool, collection int64, partitions ...int64) error {
replicas := meta.ReplicaManager.GetByCollection(collection)
releaseReq := &querypb.ReleasePartitionsRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_ReleasePartitions,
},
CollectionID: collection,
PartitionIDs: partitions,
}
for _, replica := range replicas {
for _, node := range replica.GetNodes() {
status, err := cluster.ReleasePartitions(ctx, node, releaseReq)
if ignoreErr {
continue
}
if err != nil {
return err
}
if status.GetErrorCode() != commonpb.ErrorCode_Success {
return fmt.Errorf("QueryNode failed to releasePartitions, nodeID=%d, err=%s", node, status.GetReason())
}
}
}
return nil
}

View File

@ -17,15 +17,19 @@
package meta package meta
import ( import (
"context"
"sync" "sync"
"time" "time"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/merr" "github.com/milvus-io/milvus/internal/util/merr"
"github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/internal/util/typeutil"
. "github.com/milvus-io/milvus/internal/util/typeutil" . "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/samber/lo"
) )
type Collection struct { type Collection struct {
@ -72,7 +76,7 @@ func NewCollectionManager(store Store) *CollectionManager {
// Recover recovers collections from kv store, // Recover recovers collections from kv store,
// panics if failed // panics if failed
func (m *CollectionManager) Recover() error { func (m *CollectionManager) Recover(broker Broker) error {
collections, err := m.store.GetCollections() collections, err := m.store.GetCollections()
if err != nil { if err != nil {
return err return err
@ -88,7 +92,6 @@ func (m *CollectionManager) Recover() error {
m.store.ReleaseCollection(collection.GetCollectionID()) m.store.ReleaseCollection(collection.GetCollectionID())
continue continue
} }
m.collections[collection.CollectionID] = &Collection{ m.collections[collection.CollectionID] = &Collection{
CollectionLoadInfo: collection, CollectionLoadInfo: collection,
} }
@ -104,94 +107,171 @@ func (m *CollectionManager) Recover() error {
m.store.ReleasePartition(collection, partitionIDs...) m.store.ReleasePartition(collection, partitionIDs...)
break break
} }
m.partitions[partition.PartitionID] = &Partition{ m.partitions[partition.PartitionID] = &Partition{
PartitionLoadInfo: partition, PartitionLoadInfo: partition,
} }
} }
} }
err = m.upgradeRecover(broker)
if err != nil {
log.Error("upgrade recover failed", zap.Error(err))
return err
}
return nil return nil
} }
func (m *CollectionManager) GetCollection(id UniqueID) *Collection { // upgradeRecover recovers from old version <= 2.2.x for compatibility.
m.rwmutex.RLock() func (m *CollectionManager) upgradeRecover(broker Broker) error {
defer m.rwmutex.RUnlock() for _, collection := range m.GetAllCollections() {
// It's a workaround to check if it is old CollectionLoadInfo because there's no
return m.collections[id] // loadType in old version, maybe we should use version instead.
} if collection.GetLoadType() == querypb.LoadType_UnKnownType {
partitionIDs, err := broker.GetPartitions(context.Background(), collection.GetCollectionID())
func (m *CollectionManager) GetPartition(id UniqueID) *Partition { if err != nil {
m.rwmutex.RLock() return err
defer m.rwmutex.RUnlock() }
partitions := lo.Map(partitionIDs, func(partitionID int64, _ int) *Partition {
return m.partitions[id] return &Partition{
} PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: collection.GetCollectionID(),
func (m *CollectionManager) GetLoadType(id UniqueID) querypb.LoadType { PartitionID: partitionID,
m.rwmutex.RLock() Status: querypb.LoadStatus_Loaded,
defer m.rwmutex.RUnlock() },
LoadPercentage: 100,
_, ok := m.collections[id] }
if ok { })
return querypb.LoadType_LoadCollection err = m.putPartition(partitions, true)
if err != nil {
return err
}
}
} }
if len(m.getPartitionsByCollection(id)) > 0 { for _, partition := range m.GetAllPartitions() {
return querypb.LoadType_LoadPartition // In old version, collection would NOT be stored if the partition existed.
if _, ok := m.collections[partition.GetCollectionID()]; !ok {
col := &Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: partition.GetCollectionID(),
ReplicaNumber: partition.GetReplicaNumber(),
Status: partition.GetStatus(),
FieldIndexID: partition.GetFieldIndexID(),
LoadType: querypb.LoadType_LoadPartition,
},
LoadPercentage: 100,
}
err := m.PutCollection(col)
if err != nil {
return err
}
}
}
return nil
}
func (m *CollectionManager) GetCollection(collectionID UniqueID) *Collection {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
return m.collections[collectionID]
}
func (m *CollectionManager) GetPartition(partitionID UniqueID) *Partition {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
return m.partitions[partitionID]
}
func (m *CollectionManager) GetLoadType(collectionID UniqueID) querypb.LoadType {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
collection, ok := m.collections[collectionID]
if ok {
return collection.GetLoadType()
} }
return querypb.LoadType_UnKnownType return querypb.LoadType_UnKnownType
} }
func (m *CollectionManager) GetReplicaNumber(id UniqueID) int32 { func (m *CollectionManager) GetReplicaNumber(collectionID UniqueID) int32 {
m.rwmutex.RLock() m.rwmutex.RLock()
defer m.rwmutex.RUnlock() defer m.rwmutex.RUnlock()
collection, ok := m.collections[id] collection, ok := m.collections[collectionID]
if ok { if ok {
return collection.GetReplicaNumber() return collection.GetReplicaNumber()
} }
partitions := m.getPartitionsByCollection(id) return -1
if len(partitions) > 0 { }
return partitions[0].GetReplicaNumber()
// GetCurrentLoadPercentage checks if collection is currently fully loaded.
func (m *CollectionManager) GetCurrentLoadPercentage(collectionID UniqueID) int32 {
m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
collection, ok := m.collections[collectionID]
if ok {
partitions := m.getPartitionsByCollection(collectionID)
if len(partitions) > 0 {
return lo.SumBy(partitions, func(partition *Partition) int32 {
return partition.LoadPercentage
}) / int32(len(partitions))
}
if collection.GetLoadType() == querypb.LoadType_LoadCollection {
// no partition exists
return 100
}
} }
return -1 return -1
} }
func (m *CollectionManager) GetLoadPercentage(id UniqueID) int32 { // GetCollectionLoadPercentage returns collection load percentage.
// Note: collection.LoadPercentage == 100 only means that it used to be fully loaded, and it is queryable,
// to check if it is fully loaded now, use GetCurrentLoadPercentage instead.
func (m *CollectionManager) GetCollectionLoadPercentage(collectionID UniqueID) int32 {
m.rwmutex.RLock() m.rwmutex.RLock()
defer m.rwmutex.RUnlock() defer m.rwmutex.RUnlock()
collection, ok := m.collections[id] collection, ok := m.collections[collectionID]
if ok { if ok {
return collection.LoadPercentage return collection.LoadPercentage
} }
partitions := m.getPartitionsByCollection(id) return -1
if len(partitions) > 0 { }
return lo.SumBy(partitions, func(partition *Partition) int32 {
return partition.LoadPercentage func (m *CollectionManager) GetPartitionLoadPercentage(partitionID UniqueID) int32 {
}) / int32(len(partitions)) m.rwmutex.RLock()
defer m.rwmutex.RUnlock()
partition, ok := m.partitions[partitionID]
if ok {
return partition.LoadPercentage
} }
return -1 return -1
} }
func (m *CollectionManager) GetStatus(id UniqueID) querypb.LoadStatus { func (m *CollectionManager) GetStatus(collectionID UniqueID) querypb.LoadStatus {
m.rwmutex.RLock() m.rwmutex.RLock()
defer m.rwmutex.RUnlock() defer m.rwmutex.RUnlock()
collection, ok := m.collections[id] collection, ok := m.collections[collectionID]
if ok { if !ok {
return collection.GetStatus()
}
partitions := m.getPartitionsByCollection(id)
if len(partitions) == 0 {
return querypb.LoadStatus_Invalid return querypb.LoadStatus_Invalid
} }
partitions := m.getPartitionsByCollection(collectionID)
for _, partition := range partitions { for _, partition := range partitions {
if partition.GetStatus() == querypb.LoadStatus_Loading { if partition.GetStatus() == querypb.LoadStatus_Loading {
return querypb.LoadStatus_Loading return querypb.LoadStatus_Loading
} }
} }
return querypb.LoadStatus_Loaded if len(partitions) > 0 {
return querypb.LoadStatus_Loaded
}
if collection.GetLoadType() == querypb.LoadType_LoadCollection {
return querypb.LoadStatus_Loaded
}
return querypb.LoadStatus_Invalid
} }
func (m *CollectionManager) GetFieldIndex(collectionID UniqueID) map[int64]int64 { func (m *CollectionManager) GetFieldIndex(collectionID UniqueID) map[int64]int64 {
@ -202,11 +282,7 @@ func (m *CollectionManager) GetFieldIndex(collectionID UniqueID) map[int64]int64
if ok { if ok {
return collection.GetFieldIndexID() return collection.GetFieldIndexID()
} }
partitions := m.getPartitionsByCollection(collectionID) return nil
if len(partitions) == 0 {
return nil
}
return partitions[0].GetFieldIndexID()
} }
// ContainAnyIndex returns true if the loaded collection contains one of the given indexes, // ContainAnyIndex returns true if the loaded collection contains one of the given indexes,
@ -228,31 +304,18 @@ func (m *CollectionManager) containIndex(collectionID, indexID int64) bool {
if ok { if ok {
return lo.Contains(lo.Values(collection.GetFieldIndexID()), indexID) return lo.Contains(lo.Values(collection.GetFieldIndexID()), indexID)
} }
partitions := m.getPartitionsByCollection(collectionID)
if len(partitions) == 0 {
return false
}
for _, partition := range partitions {
if lo.Contains(lo.Values(partition.GetFieldIndexID()), indexID) {
return true
}
}
return false return false
} }
func (m *CollectionManager) Exist(id UniqueID) bool { func (m *CollectionManager) Exist(collectionID UniqueID) bool {
m.rwmutex.RLock() m.rwmutex.RLock()
defer m.rwmutex.RUnlock() defer m.rwmutex.RUnlock()
_, ok := m.collections[id] _, ok := m.collections[collectionID]
if ok { return ok
return true
}
partitions := m.getPartitionsByCollection(id)
return len(partitions) > 0
} }
// GetAll returns the collection ID of all loaded collections and partitions // GetAll returns the collection ID of all loaded collections
func (m *CollectionManager) GetAll() []int64 { func (m *CollectionManager) GetAll() []int64 {
m.rwmutex.RLock() m.rwmutex.RLock()
defer m.rwmutex.RUnlock() defer m.rwmutex.RUnlock()
@ -261,9 +324,6 @@ func (m *CollectionManager) GetAll() []int64 {
for _, collection := range m.collections { for _, collection := range m.collections {
ids.Insert(collection.GetCollectionID()) ids.Insert(collection.GetCollectionID())
} }
for _, partition := range m.partitions {
ids.Insert(partition.GetCollectionID())
}
return ids.Collect() return ids.Collect()
} }
@ -298,11 +358,11 @@ func (m *CollectionManager) getPartitionsByCollection(collectionID UniqueID) []*
return partitions return partitions
} }
func (m *CollectionManager) PutCollection(collection *Collection) error { func (m *CollectionManager) PutCollection(collection *Collection, partitions ...*Partition) error {
m.rwmutex.Lock() m.rwmutex.Lock()
defer m.rwmutex.Unlock() defer m.rwmutex.Unlock()
return m.putCollection(collection, true) return m.putCollection(true, collection, partitions...)
} }
func (m *CollectionManager) UpdateCollection(collection *Collection) error { func (m *CollectionManager) UpdateCollection(collection *Collection) error {
@ -314,7 +374,7 @@ func (m *CollectionManager) UpdateCollection(collection *Collection) error {
return merr.WrapErrCollectionNotFound(collection.GetCollectionID()) return merr.WrapErrCollectionNotFound(collection.GetCollectionID())
} }
return m.putCollection(collection, true) return m.putCollection(true, collection)
} }
func (m *CollectionManager) UpdateCollectionInMemory(collection *Collection) bool { func (m *CollectionManager) UpdateCollectionInMemory(collection *Collection) bool {
@ -326,17 +386,24 @@ func (m *CollectionManager) UpdateCollectionInMemory(collection *Collection) boo
return false return false
} }
m.putCollection(collection, false) m.putCollection(false, collection)
return true return true
} }
func (m *CollectionManager) putCollection(collection *Collection, withSave bool) error { func (m *CollectionManager) putCollection(withSave bool, collection *Collection, partitions ...*Partition) error {
if withSave { if withSave {
err := m.store.SaveCollection(collection.CollectionLoadInfo) partitionInfos := lo.Map(partitions, func(partition *Partition, _ int) *querypb.PartitionLoadInfo {
return partition.PartitionLoadInfo
})
err := m.store.SaveCollection(collection.CollectionLoadInfo, partitionInfos...)
if err != nil { if err != nil {
return err return err
} }
} }
for _, partition := range partitions {
partition.UpdatedAt = time.Now()
m.partitions[partition.GetPartitionID()] = partition
}
collection.UpdatedAt = time.Now() collection.UpdatedAt = time.Now()
m.collections[collection.CollectionID] = collection m.collections[collection.CollectionID] = collection
@ -399,25 +466,25 @@ func (m *CollectionManager) putPartition(partitions []*Partition, withSave bool)
return nil return nil
} }
func (m *CollectionManager) RemoveCollection(id UniqueID) error { // RemoveCollection removes collection and its partitions.
func (m *CollectionManager) RemoveCollection(collectionID UniqueID) error {
m.rwmutex.Lock() m.rwmutex.Lock()
defer m.rwmutex.Unlock() defer m.rwmutex.Unlock()
_, ok := m.collections[id] _, ok := m.collections[collectionID]
if ok { if ok {
err := m.store.ReleaseCollection(id) err := m.store.ReleaseCollection(collectionID)
if err != nil { if err != nil {
return err return err
} }
delete(m.collections, id) delete(m.collections, collectionID)
return nil for partID, partition := range m.partitions {
if partition.CollectionID == collectionID {
delete(m.partitions, partID)
}
}
} }
return nil
partitions := lo.Map(m.getPartitionsByCollection(id),
func(partition *Partition, _ int) int64 {
return partition.GetPartitionID()
})
return m.removePartition(partitions...)
} }
func (m *CollectionManager) RemovePartition(ids ...UniqueID) error { func (m *CollectionManager) RemovePartition(ids ...UniqueID) error {

View File

@ -21,12 +21,15 @@ import (
"testing" "testing"
"time" "time"
"github.com/samber/lo"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus/internal/kv" "github.com/milvus-io/milvus/internal/kv"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
"github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/querypb"
. "github.com/milvus-io/milvus/internal/querycoordv2/params" . "github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/etcd"
"github.com/stretchr/testify/suite"
) )
type CollectionManagerSuite struct { type CollectionManagerSuite struct {
@ -37,11 +40,13 @@ type CollectionManagerSuite struct {
partitions map[int64][]int64 // CollectionID -> PartitionIDs partitions map[int64][]int64 // CollectionID -> PartitionIDs
loadTypes []querypb.LoadType loadTypes []querypb.LoadType
replicaNumber []int32 replicaNumber []int32
loadPercentage []int32 colLoadPercent []int32
parLoadPercent map[int64][]int32
// Mocks // Mocks
kv kv.MetaKv kv kv.MetaKv
store Store store Store
broker *MockBroker
// Test object // Test object
mgr *CollectionManager mgr *CollectionManager
@ -50,19 +55,27 @@ type CollectionManagerSuite struct {
func (suite *CollectionManagerSuite) SetupSuite() { func (suite *CollectionManagerSuite) SetupSuite() {
Params.Init() Params.Init()
suite.collections = []int64{100, 101, 102} suite.collections = []int64{100, 101, 102, 103}
suite.partitions = map[int64][]int64{ suite.partitions = map[int64][]int64{
100: {10}, 100: {10},
101: {11, 12}, 101: {11, 12},
102: {13, 14, 15}, 102: {13, 14, 15},
103: {}, // not partition in this col
} }
suite.loadTypes = []querypb.LoadType{ suite.loadTypes = []querypb.LoadType{
querypb.LoadType_LoadCollection, querypb.LoadType_LoadCollection,
querypb.LoadType_LoadPartition, querypb.LoadType_LoadPartition,
querypb.LoadType_LoadCollection, querypb.LoadType_LoadCollection,
querypb.LoadType_LoadCollection,
}
suite.replicaNumber = []int32{1, 2, 3, 1}
suite.colLoadPercent = []int32{0, 50, 100, 100}
suite.parLoadPercent = map[int64][]int32{
100: {0},
101: {0, 100},
102: {100, 100, 100},
103: {},
} }
suite.replicaNumber = []int32{1, 2, 3}
suite.loadPercentage = []int32{0, 50, 100}
} }
func (suite *CollectionManagerSuite) SetupTest() { func (suite *CollectionManagerSuite) SetupTest() {
@ -79,6 +92,7 @@ func (suite *CollectionManagerSuite) SetupTest() {
suite.Require().NoError(err) suite.Require().NoError(err)
suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue()) suite.kv = etcdkv.NewEtcdKV(cli, config.MetaRootPath.GetValue())
suite.store = NewMetaStore(suite.kv) suite.store = NewMetaStore(suite.kv)
suite.broker = NewMockBroker(suite.T())
suite.mgr = NewCollectionManager(suite.store) suite.mgr = NewCollectionManager(suite.store)
suite.loadAll() suite.loadAll()
@ -94,18 +108,18 @@ func (suite *CollectionManagerSuite) TestGetProperty() {
for i, collection := range suite.collections { for i, collection := range suite.collections {
loadType := mgr.GetLoadType(collection) loadType := mgr.GetLoadType(collection)
replicaNumber := mgr.GetReplicaNumber(collection) replicaNumber := mgr.GetReplicaNumber(collection)
percentage := mgr.GetLoadPercentage(collection) percentage := mgr.GetCurrentLoadPercentage(collection)
exist := mgr.Exist(collection) exist := mgr.Exist(collection)
suite.Equal(suite.loadTypes[i], loadType) suite.Equal(suite.loadTypes[i], loadType)
suite.Equal(suite.replicaNumber[i], replicaNumber) suite.Equal(suite.replicaNumber[i], replicaNumber)
suite.Equal(suite.loadPercentage[i], percentage) suite.Equal(suite.colLoadPercent[i], percentage)
suite.True(exist) suite.True(exist)
} }
invalidCollection := -1 invalidCollection := -1
loadType := mgr.GetLoadType(int64(invalidCollection)) loadType := mgr.GetLoadType(int64(invalidCollection))
replicaNumber := mgr.GetReplicaNumber(int64(invalidCollection)) replicaNumber := mgr.GetReplicaNumber(int64(invalidCollection))
percentage := mgr.GetLoadPercentage(int64(invalidCollection)) percentage := mgr.GetCurrentLoadPercentage(int64(invalidCollection))
exist := mgr.Exist(int64(invalidCollection)) exist := mgr.Exist(int64(invalidCollection))
suite.Equal(querypb.LoadType_UnKnownType, loadType) suite.Equal(querypb.LoadType_UnKnownType, loadType)
suite.EqualValues(-1, replicaNumber) suite.EqualValues(-1, replicaNumber)
@ -113,33 +127,45 @@ func (suite *CollectionManagerSuite) TestGetProperty() {
suite.False(exist) suite.False(exist)
} }
func (suite *CollectionManagerSuite) TestGet() { func (suite *CollectionManagerSuite) TestPut() {
mgr := suite.mgr suite.releaseAll()
// test put collection with partitions
allCollections := mgr.GetAllCollections() for i, collection := range suite.collections {
allPartitions := mgr.GetAllPartitions() status := querypb.LoadStatus_Loaded
for i, collectionID := range suite.collections { if suite.colLoadPercent[i] < 100 {
if suite.loadTypes[i] == querypb.LoadType_LoadCollection { status = querypb.LoadStatus_Loading
collection := mgr.GetCollection(collectionID)
suite.Equal(collectionID, collection.GetCollectionID())
suite.Contains(allCollections, collection)
} else {
partitions := mgr.GetPartitionsByCollection(collectionID)
suite.Len(partitions, len(suite.partitions[collectionID]))
for _, partitionID := range suite.partitions[collectionID] {
partition := mgr.GetPartition(partitionID)
suite.Equal(collectionID, partition.GetCollectionID())
suite.Equal(partitionID, partition.GetPartitionID())
suite.Contains(partitions, partition)
suite.Contains(allPartitions, partition)
}
} }
}
all := mgr.GetAll() col := &Collection{
sort.Slice(all, func(i, j int) bool { return all[i] < all[j] }) CollectionLoadInfo: &querypb.CollectionLoadInfo{
suite.Equal(suite.collections, all) CollectionID: collection,
ReplicaNumber: suite.replicaNumber[i],
Status: status,
LoadType: suite.loadTypes[i],
},
LoadPercentage: suite.colLoadPercent[i],
CreatedAt: time.Now(),
}
partitions := lo.Map(suite.partitions[collection], func(partition int64, j int) *Partition {
return &Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: collection,
PartitionID: partition,
ReplicaNumber: suite.replicaNumber[i],
Status: status,
},
LoadPercentage: suite.parLoadPercent[collection][j],
CreatedAt: time.Now(),
}
})
err := suite.mgr.PutCollection(col, partitions...)
suite.NoError(err)
}
suite.checkLoadResult()
}
func (suite *CollectionManagerSuite) TestGet() {
suite.checkLoadResult()
} }
func (suite *CollectionManagerSuite) TestUpdate() { func (suite *CollectionManagerSuite) TestUpdate() {
@ -177,7 +203,7 @@ func (suite *CollectionManagerSuite) TestUpdate() {
} }
suite.clearMemory() suite.clearMemory()
err := mgr.Recover() err := mgr.Recover(suite.broker)
suite.NoError(err) suite.NoError(err)
collections = mgr.GetAllCollections() collections = mgr.GetAllCollections()
partitions = mgr.GetAllPartitions() partitions = mgr.GetAllPartitions()
@ -215,7 +241,7 @@ func (suite *CollectionManagerSuite) TestRemove() {
} }
// Make sure the removes applied to meta store // Make sure the removes applied to meta store
err := mgr.Recover() err := mgr.Recover(suite.broker)
suite.NoError(err) suite.NoError(err)
for i, collectionID := range suite.collections { for i, collectionID := range suite.collections {
if suite.loadTypes[i] == querypb.LoadType_LoadCollection { if suite.loadTypes[i] == querypb.LoadType_LoadCollection {
@ -237,37 +263,50 @@ func (suite *CollectionManagerSuite) TestRemove() {
suite.Empty(partitions) suite.Empty(partitions)
} }
} }
// remove collection would release its partitions also
suite.releaseAll()
suite.loadAll()
for _, collectionID := range suite.collections {
err := mgr.RemoveCollection(collectionID)
suite.NoError(err)
err = mgr.Recover(suite.broker)
suite.NoError(err)
collection := mgr.GetCollection(collectionID)
suite.Nil(collection)
partitions := mgr.GetPartitionsByCollection(collectionID)
suite.Empty(partitions)
}
} }
func (suite *CollectionManagerSuite) TestRecover() { func (suite *CollectionManagerSuite) TestRecover() {
mgr := suite.mgr mgr := suite.mgr
suite.clearMemory() suite.clearMemory()
err := mgr.Recover() err := mgr.Recover(suite.broker)
suite.NoError(err) suite.NoError(err)
for i, collection := range suite.collections { for i, collection := range suite.collections {
exist := suite.loadPercentage[i] == 100 exist := suite.colLoadPercent[i] == 100
suite.Equal(exist, mgr.Exist(collection)) suite.Equal(exist, mgr.Exist(collection))
} }
} }
func (suite *CollectionManagerSuite) loadAll() { func (suite *CollectionManagerSuite) TestUpgradeRecover() {
suite.releaseAll()
mgr := suite.mgr mgr := suite.mgr
// put old version of collections and partitions
for i, collection := range suite.collections { for i, collection := range suite.collections {
status := querypb.LoadStatus_Loaded status := querypb.LoadStatus_Loaded
if suite.loadPercentage[i] < 100 {
status = querypb.LoadStatus_Loading
}
if suite.loadTypes[i] == querypb.LoadType_LoadCollection { if suite.loadTypes[i] == querypb.LoadType_LoadCollection {
mgr.PutCollection(&Collection{ mgr.PutCollection(&Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: collection, CollectionID: collection,
ReplicaNumber: suite.replicaNumber[i], ReplicaNumber: suite.replicaNumber[i],
Status: status, Status: status,
LoadType: querypb.LoadType_UnKnownType, // old version's collection didn't set loadType
}, },
LoadPercentage: suite.loadPercentage[i], LoadPercentage: suite.colLoadPercent[i],
CreatedAt: time.Now(), CreatedAt: time.Now(),
}) })
} else { } else {
@ -279,12 +318,92 @@ func (suite *CollectionManagerSuite) loadAll() {
ReplicaNumber: suite.replicaNumber[i], ReplicaNumber: suite.replicaNumber[i],
Status: status, Status: status,
}, },
LoadPercentage: suite.loadPercentage[i], LoadPercentage: suite.colLoadPercent[i],
CreatedAt: time.Now(), CreatedAt: time.Now(),
}) })
} }
} }
} }
// set expectations
for i, collection := range suite.collections {
if suite.loadTypes[i] == querypb.LoadType_LoadCollection {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil)
}
}
// do recovery
suite.clearMemory()
err := mgr.Recover(suite.broker)
suite.NoError(err)
suite.checkLoadResult()
}
func (suite *CollectionManagerSuite) loadAll() {
mgr := suite.mgr
for i, collection := range suite.collections {
status := querypb.LoadStatus_Loaded
if suite.colLoadPercent[i] < 100 {
status = querypb.LoadStatus_Loading
}
mgr.PutCollection(&Collection{
CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionID: collection,
ReplicaNumber: suite.replicaNumber[i],
Status: status,
LoadType: suite.loadTypes[i],
},
LoadPercentage: suite.colLoadPercent[i],
CreatedAt: time.Now(),
})
for j, partition := range suite.partitions[collection] {
mgr.PutPartition(&Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: collection,
PartitionID: partition,
Status: status,
},
LoadPercentage: suite.parLoadPercent[collection][j],
CreatedAt: time.Now(),
})
}
}
}
func (suite *CollectionManagerSuite) checkLoadResult() {
mgr := suite.mgr
allCollections := mgr.GetAllCollections()
allPartitions := mgr.GetAllPartitions()
for _, collectionID := range suite.collections {
collection := mgr.GetCollection(collectionID)
suite.Equal(collectionID, collection.GetCollectionID())
suite.Contains(allCollections, collection)
partitions := mgr.GetPartitionsByCollection(collectionID)
suite.Len(partitions, len(suite.partitions[collectionID]))
for _, partitionID := range suite.partitions[collectionID] {
partition := mgr.GetPartition(partitionID)
suite.Equal(collectionID, partition.GetCollectionID())
suite.Equal(partitionID, partition.GetPartitionID())
suite.Contains(partitions, partition)
suite.Contains(allPartitions, partition)
}
}
all := mgr.GetAll()
sort.Slice(all, func(i, j int) bool { return all[i] < all[j] })
suite.Equal(suite.collections, all)
}
func (suite *CollectionManagerSuite) releaseAll() {
for _, collection := range suite.collections {
err := suite.mgr.RemoveCollection(collection)
suite.NoError(err)
}
} }
func (suite *CollectionManagerSuite) clearMemory() { func (suite *CollectionManagerSuite) clearMemory() {

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.16.0. DO NOT EDIT. // Code generated by mockery v2.14.0. DO NOT EDIT.
package meta package meta
@ -200,13 +200,13 @@ func (_c *MockStore_GetResourceGroups_Call) Return(_a0 []*querypb.ResourceGroup,
return _c return _c
} }
// ReleaseCollection provides a mock function with given fields: id // ReleaseCollection provides a mock function with given fields: collection
func (_m *MockStore) ReleaseCollection(id int64) error { func (_m *MockStore) ReleaseCollection(collection int64) error {
ret := _m.Called(id) ret := _m.Called(collection)
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(int64) error); ok { if rf, ok := ret.Get(0).(func(int64) error); ok {
r0 = rf(id) r0 = rf(collection)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)
} }
@ -220,12 +220,12 @@ type MockStore_ReleaseCollection_Call struct {
} }
// ReleaseCollection is a helper method to define mock.On call // ReleaseCollection is a helper method to define mock.On call
// - id int64 // - collection int64
func (_e *MockStore_Expecter) ReleaseCollection(id interface{}) *MockStore_ReleaseCollection_Call { func (_e *MockStore_Expecter) ReleaseCollection(collection interface{}) *MockStore_ReleaseCollection_Call {
return &MockStore_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", id)} return &MockStore_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", collection)}
} }
func (_c *MockStore_ReleaseCollection_Call) Run(run func(id int64)) *MockStore_ReleaseCollection_Call { func (_c *MockStore_ReleaseCollection_Call) Run(run func(collection int64)) *MockStore_ReleaseCollection_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
run(args[0].(int64)) run(args[0].(int64))
}) })
@ -264,8 +264,8 @@ type MockStore_ReleasePartition_Call struct {
} }
// ReleasePartition is a helper method to define mock.On call // ReleasePartition is a helper method to define mock.On call
// - collection int64 // - collection int64
// - partitions ...int64 // - partitions ...int64
func (_e *MockStore_Expecter) ReleasePartition(collection interface{}, partitions ...interface{}) *MockStore_ReleasePartition_Call { func (_e *MockStore_Expecter) ReleasePartition(collection interface{}, partitions ...interface{}) *MockStore_ReleasePartition_Call {
return &MockStore_ReleasePartition_Call{Call: _e.mock.On("ReleasePartition", return &MockStore_ReleasePartition_Call{Call: _e.mock.On("ReleasePartition",
append([]interface{}{collection}, partitions...)...)} append([]interface{}{collection}, partitions...)...)}
@ -309,8 +309,8 @@ type MockStore_ReleaseReplica_Call struct {
} }
// ReleaseReplica is a helper method to define mock.On call // ReleaseReplica is a helper method to define mock.On call
// - collection int64 // - collection int64
// - replica int64 // - replica int64
func (_e *MockStore_Expecter) ReleaseReplica(collection interface{}, replica interface{}) *MockStore_ReleaseReplica_Call { func (_e *MockStore_Expecter) ReleaseReplica(collection interface{}, replica interface{}) *MockStore_ReleaseReplica_Call {
return &MockStore_ReleaseReplica_Call{Call: _e.mock.On("ReleaseReplica", collection, replica)} return &MockStore_ReleaseReplica_Call{Call: _e.mock.On("ReleaseReplica", collection, replica)}
} }
@ -347,7 +347,7 @@ type MockStore_ReleaseReplicas_Call struct {
} }
// ReleaseReplicas is a helper method to define mock.On call // ReleaseReplicas is a helper method to define mock.On call
// - collectionID int64 // - collectionID int64
func (_e *MockStore_Expecter) ReleaseReplicas(collectionID interface{}) *MockStore_ReleaseReplicas_Call { func (_e *MockStore_Expecter) ReleaseReplicas(collectionID interface{}) *MockStore_ReleaseReplicas_Call {
return &MockStore_ReleaseReplicas_Call{Call: _e.mock.On("ReleaseReplicas", collectionID)} return &MockStore_ReleaseReplicas_Call{Call: _e.mock.On("ReleaseReplicas", collectionID)}
} }
@ -384,7 +384,7 @@ type MockStore_RemoveResourceGroup_Call struct {
} }
// RemoveResourceGroup is a helper method to define mock.On call // RemoveResourceGroup is a helper method to define mock.On call
// - rgName string // - rgName string
func (_e *MockStore_Expecter) RemoveResourceGroup(rgName interface{}) *MockStore_RemoveResourceGroup_Call { func (_e *MockStore_Expecter) RemoveResourceGroup(rgName interface{}) *MockStore_RemoveResourceGroup_Call {
return &MockStore_RemoveResourceGroup_Call{Call: _e.mock.On("RemoveResourceGroup", rgName)} return &MockStore_RemoveResourceGroup_Call{Call: _e.mock.On("RemoveResourceGroup", rgName)}
} }
@ -401,13 +401,20 @@ func (_c *MockStore_RemoveResourceGroup_Call) Return(_a0 error) *MockStore_Remov
return _c return _c
} }
// SaveCollection provides a mock function with given fields: info // SaveCollection provides a mock function with given fields: collection, partitions
func (_m *MockStore) SaveCollection(info *querypb.CollectionLoadInfo) error { func (_m *MockStore) SaveCollection(collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo) error {
ret := _m.Called(info) _va := make([]interface{}, len(partitions))
for _i := range partitions {
_va[_i] = partitions[_i]
}
var _ca []interface{}
_ca = append(_ca, collection)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
var r0 error var r0 error
if rf, ok := ret.Get(0).(func(*querypb.CollectionLoadInfo) error); ok { if rf, ok := ret.Get(0).(func(*querypb.CollectionLoadInfo, ...*querypb.PartitionLoadInfo) error); ok {
r0 = rf(info) r0 = rf(collection, partitions...)
} else { } else {
r0 = ret.Error(0) r0 = ret.Error(0)
} }
@ -421,14 +428,22 @@ type MockStore_SaveCollection_Call struct {
} }
// SaveCollection is a helper method to define mock.On call // SaveCollection is a helper method to define mock.On call
// - info *querypb.CollectionLoadInfo // - collection *querypb.CollectionLoadInfo
func (_e *MockStore_Expecter) SaveCollection(info interface{}) *MockStore_SaveCollection_Call { // - partitions ...*querypb.PartitionLoadInfo
return &MockStore_SaveCollection_Call{Call: _e.mock.On("SaveCollection", info)} func (_e *MockStore_Expecter) SaveCollection(collection interface{}, partitions ...interface{}) *MockStore_SaveCollection_Call {
return &MockStore_SaveCollection_Call{Call: _e.mock.On("SaveCollection",
append([]interface{}{collection}, partitions...)...)}
} }
func (_c *MockStore_SaveCollection_Call) Run(run func(info *querypb.CollectionLoadInfo)) *MockStore_SaveCollection_Call { func (_c *MockStore_SaveCollection_Call) Run(run func(collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo)) *MockStore_SaveCollection_Call {
_c.Call.Run(func(args mock.Arguments) { _c.Call.Run(func(args mock.Arguments) {
run(args[0].(*querypb.CollectionLoadInfo)) variadicArgs := make([]*querypb.PartitionLoadInfo, len(args)-1)
for i, a := range args[1:] {
if a != nil {
variadicArgs[i] = a.(*querypb.PartitionLoadInfo)
}
}
run(args[0].(*querypb.CollectionLoadInfo), variadicArgs...)
}) })
return _c return _c
} }
@ -464,7 +479,7 @@ type MockStore_SavePartition_Call struct {
} }
// SavePartition is a helper method to define mock.On call // SavePartition is a helper method to define mock.On call
// - info ...*querypb.PartitionLoadInfo // - info ...*querypb.PartitionLoadInfo
func (_e *MockStore_Expecter) SavePartition(info ...interface{}) *MockStore_SavePartition_Call { func (_e *MockStore_Expecter) SavePartition(info ...interface{}) *MockStore_SavePartition_Call {
return &MockStore_SavePartition_Call{Call: _e.mock.On("SavePartition", return &MockStore_SavePartition_Call{Call: _e.mock.On("SavePartition",
append([]interface{}{}, info...)...)} append([]interface{}{}, info...)...)}
@ -508,7 +523,7 @@ type MockStore_SaveReplica_Call struct {
} }
// SaveReplica is a helper method to define mock.On call // SaveReplica is a helper method to define mock.On call
// - replica *querypb.Replica // - replica *querypb.Replica
func (_e *MockStore_Expecter) SaveReplica(replica interface{}) *MockStore_SaveReplica_Call { func (_e *MockStore_Expecter) SaveReplica(replica interface{}) *MockStore_SaveReplica_Call {
return &MockStore_SaveReplica_Call{Call: _e.mock.On("SaveReplica", replica)} return &MockStore_SaveReplica_Call{Call: _e.mock.On("SaveReplica", replica)}
} }
@ -551,7 +566,7 @@ type MockStore_SaveResourceGroup_Call struct {
} }
// SaveResourceGroup is a helper method to define mock.On call // SaveResourceGroup is a helper method to define mock.On call
// - rgs ...*querypb.ResourceGroup // - rgs ...*querypb.ResourceGroup
func (_e *MockStore_Expecter) SaveResourceGroup(rgs ...interface{}) *MockStore_SaveResourceGroup_Call { func (_e *MockStore_Expecter) SaveResourceGroup(rgs ...interface{}) *MockStore_SaveResourceGroup_Call {
return &MockStore_SaveResourceGroup_Call{Call: _e.mock.On("SaveResourceGroup", return &MockStore_SaveResourceGroup_Call{Call: _e.mock.On("SaveResourceGroup",
append([]interface{}{}, rgs...)...)} append([]interface{}{}, rgs...)...)}

View File

@ -61,13 +61,23 @@ func NewMetaStore(cli kv.MetaKv) metaStore {
} }
} }
func (s metaStore) SaveCollection(info *querypb.CollectionLoadInfo) error { func (s metaStore) SaveCollection(collection *querypb.CollectionLoadInfo, partitions ...*querypb.PartitionLoadInfo) error {
k := encodeCollectionLoadInfoKey(info.GetCollectionID()) k := encodeCollectionLoadInfoKey(collection.GetCollectionID())
v, err := proto.Marshal(info) v, err := proto.Marshal(collection)
if err != nil { if err != nil {
return err return err
} }
return s.cli.Save(k, string(v)) kvs := make(map[string]string)
for _, partition := range partitions {
key := encodePartitionLoadInfoKey(partition.GetCollectionID(), partition.GetPartitionID())
value, err := proto.Marshal(partition)
if err != nil {
return err
}
kvs[key] = string(value)
}
kvs[k] = string(v)
return s.cli.MultiSave(kvs)
} }
func (s metaStore) SavePartition(info ...*querypb.PartitionLoadInfo) error { func (s metaStore) SavePartition(info ...*querypb.PartitionLoadInfo) error {
@ -211,9 +221,27 @@ func (s metaStore) GetResourceGroups() ([]*querypb.ResourceGroup, error) {
return ret, nil return ret, nil
} }
func (s metaStore) ReleaseCollection(id int64) error { func (s metaStore) ReleaseCollection(collection int64) error {
k := encodeCollectionLoadInfoKey(id) // obtain partitions of this collection
return s.cli.Remove(k) _, values, err := s.cli.LoadWithPrefix(fmt.Sprintf("%s/%d", PartitionLoadInfoPrefix, collection))
if err != nil {
return err
}
partitions := make([]*querypb.PartitionLoadInfo, 0)
for _, v := range values {
info := querypb.PartitionLoadInfo{}
if err = proto.Unmarshal([]byte(v), &info); err != nil {
return err
}
partitions = append(partitions, &info)
}
// remove collection and obtained partitions
keys := lo.Map(partitions, func(partition *querypb.PartitionLoadInfo, _ int) string {
return encodePartitionLoadInfoKey(collection, partition.GetPartitionID())
})
k := encodeCollectionLoadInfoKey(collection)
keys = append(keys, k)
return s.cli.MultiRemove(keys)
} }
func (s metaStore) ReleasePartition(collection int64, partitions ...int64) error { func (s metaStore) ReleasePartition(collection int64, partitions ...int64) error {

View File

@ -81,6 +81,39 @@ func (suite *StoreTestSuite) TestCollection() {
suite.Len(collections, 1) suite.Len(collections, 1)
} }
func (suite *StoreTestSuite) TestCollectionWithPartition() {
suite.store.SaveCollection(&querypb.CollectionLoadInfo{
CollectionID: 1,
})
suite.store.SaveCollection(&querypb.CollectionLoadInfo{
CollectionID: 2,
}, &querypb.PartitionLoadInfo{
CollectionID: 2,
PartitionID: 102,
})
suite.store.SaveCollection(&querypb.CollectionLoadInfo{
CollectionID: 3,
}, &querypb.PartitionLoadInfo{
CollectionID: 3,
PartitionID: 103,
})
suite.store.ReleaseCollection(1)
suite.store.ReleaseCollection(2)
collections, err := suite.store.GetCollections()
suite.NoError(err)
suite.Len(collections, 1)
suite.Equal(int64(3), collections[0].GetCollectionID())
partitions, err := suite.store.GetPartitions()
suite.NoError(err)
suite.Len(partitions, 1)
suite.Len(partitions[int64(3)], 1)
suite.Equal(int64(103), partitions[int64(3)][0].GetPartitionID())
}
func (suite *StoreTestSuite) TestPartition() { func (suite *StoreTestSuite) TestPartition() {
suite.store.SavePartition(&querypb.PartitionLoadInfo{ suite.store.SavePartition(&querypb.PartitionLoadInfo{
PartitionID: 1, PartitionID: 1,

View File

@ -106,22 +106,10 @@ func (mgr *TargetManager) UpdateCollectionNextTarget(collectionID int64) error {
mgr.rwMutex.Lock() mgr.rwMutex.Lock()
defer mgr.rwMutex.Unlock() defer mgr.rwMutex.Unlock()
partitionIDs := make([]int64, 0) partitions := mgr.meta.GetPartitionsByCollection(collectionID)
collection := mgr.meta.GetCollection(collectionID) partitionIDs := lo.Map(partitions, func(partition *Partition, i int) int64 {
if collection != nil { return partition.PartitionID
var err error })
partitionIDs, err = mgr.broker.GetPartitions(context.Background(), collectionID)
if err != nil {
return err
}
} else {
partitions := mgr.meta.GetPartitionsByCollection(collectionID)
if partitions != nil {
partitionIDs = lo.Map(partitions, func(partition *Partition, i int) int64 {
return partition.PartitionID
})
}
}
return mgr.updateCollectionNextTarget(collectionID, partitionIDs...) return mgr.updateCollectionNextTarget(collectionID, partitionIDs...)
} }
@ -146,14 +134,27 @@ func (mgr *TargetManager) updateCollectionNextTarget(collectionID int64, partiti
return nil return nil
} }
func (mgr *TargetManager) PullNextTarget(broker Broker, collectionID int64, partitionIDs ...int64) (*CollectionTarget, error) { func (mgr *TargetManager) PullNextTarget(broker Broker, collectionID int64, chosenPartitionIDs ...int64) (*CollectionTarget, error) {
log.Info("start to pull next targets for partition", log.Info("start to pull next targets for partition",
zap.Int64("collectionID", collectionID), zap.Int64("collectionID", collectionID),
zap.Int64s("partitionIDs", partitionIDs)) zap.Int64s("chosenPartitionIDs", chosenPartitionIDs))
channelInfos := make(map[string][]*datapb.VchannelInfo) channelInfos := make(map[string][]*datapb.VchannelInfo)
segments := make(map[int64]*datapb.SegmentInfo, 0) segments := make(map[int64]*datapb.SegmentInfo, 0)
for _, partitionID := range partitionIDs { dmChannels := make(map[string]*DmChannel)
if len(chosenPartitionIDs) == 0 {
return NewCollectionTarget(segments, dmChannels), nil
}
fullPartitions, err := broker.GetPartitions(context.Background(), collectionID)
if err != nil {
return nil, err
}
// we should pull `channel targets` from all partitions because QueryNodes need to load
// the complete growing segments. And we should pull `segments targets` only from the chosen partitions.
for _, partitionID := range fullPartitions {
log.Debug("get recovery info...", log.Debug("get recovery info...",
zap.Int64("collectionID", collectionID), zap.Int64("collectionID", collectionID),
zap.Int64("partitionID", partitionID)) zap.Int64("partitionID", partitionID))
@ -161,7 +162,12 @@ func (mgr *TargetManager) PullNextTarget(broker Broker, collectionID int64, part
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, info := range vChannelInfos {
channelInfos[info.GetChannelName()] = append(channelInfos[info.GetChannelName()], info)
}
if !lo.Contains(chosenPartitionIDs, partitionID) {
continue
}
for _, binlog := range binlogs { for _, binlog := range binlogs {
segments[binlog.GetSegmentID()] = &datapb.SegmentInfo{ segments[binlog.GetSegmentID()] = &datapb.SegmentInfo{
ID: binlog.GetSegmentID(), ID: binlog.GetSegmentID(),
@ -174,18 +180,12 @@ func (mgr *TargetManager) PullNextTarget(broker Broker, collectionID int64, part
Deltalogs: binlog.GetDeltalogs(), Deltalogs: binlog.GetDeltalogs(),
} }
} }
for _, info := range vChannelInfos {
channelInfos[info.GetChannelName()] = append(channelInfos[info.GetChannelName()], info)
}
} }
dmChannels := make(map[string]*DmChannel)
for _, infos := range channelInfos { for _, infos := range channelInfos {
merged := mgr.mergeDmChannelInfo(infos) merged := mgr.mergeDmChannelInfo(infos)
dmChannels[merged.GetChannelName()] = merged dmChannels[merged.GetChannelName()] = merged
} }
return NewCollectionTarget(segments, dmChannels), nil return NewCollectionTarget(segments, dmChannels), nil
} }

View File

@ -126,7 +126,7 @@ func (suite *TargetManagerSuite) SetupTest() {
} }
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collection, partition).Return(dmChannels, allSegments, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collection, partition).Return(dmChannels, allSegments, nil)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil)
suite.mgr.UpdateCollectionNextTargetWithPartitions(collection, suite.partitions[collection]...) suite.mgr.UpdateCollectionNextTargetWithPartitions(collection, suite.partitions[collection]...)
} }
} }
@ -192,6 +192,7 @@ func (suite *TargetManagerSuite) TestUpdateNextTarget() {
}, },
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, collectionID).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collectionID, int64(1)).Return(nextTargetChannels, nextTargetSegments, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collectionID, int64(1)).Return(nextTargetChannels, nextTargetSegments, nil)
suite.mgr.UpdateCollectionNextTargetWithPartitions(collectionID, int64(1)) suite.mgr.UpdateCollectionNextTargetWithPartitions(collectionID, int64(1))
suite.assertSegments([]int64{11, 12}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, NextTarget)) suite.assertSegments([]int64{11, 12}, suite.mgr.GetHistoricalSegmentsByCollection(collectionID, NextTarget))

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.16.0. DO NOT EDIT. // Code generated by mockery v2.14.0. DO NOT EDIT.
package mocks package mocks
@ -58,8 +58,8 @@ type MockQueryNodeServer_GetComponentStates_Call struct {
} }
// GetComponentStates is a helper method to define mock.On call // GetComponentStates is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *milvuspb.GetComponentStatesRequest // - _a1 *milvuspb.GetComponentStatesRequest
func (_e *MockQueryNodeServer_Expecter) GetComponentStates(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetComponentStates_Call { func (_e *MockQueryNodeServer_Expecter) GetComponentStates(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetComponentStates_Call {
return &MockQueryNodeServer_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", _a0, _a1)} return &MockQueryNodeServer_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", _a0, _a1)}
} }
@ -105,8 +105,8 @@ type MockQueryNodeServer_GetDataDistribution_Call struct {
} }
// GetDataDistribution is a helper method to define mock.On call // GetDataDistribution is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.GetDataDistributionRequest // - _a1 *querypb.GetDataDistributionRequest
func (_e *MockQueryNodeServer_Expecter) GetDataDistribution(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetDataDistribution_Call { func (_e *MockQueryNodeServer_Expecter) GetDataDistribution(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetDataDistribution_Call {
return &MockQueryNodeServer_GetDataDistribution_Call{Call: _e.mock.On("GetDataDistribution", _a0, _a1)} return &MockQueryNodeServer_GetDataDistribution_Call{Call: _e.mock.On("GetDataDistribution", _a0, _a1)}
} }
@ -152,8 +152,8 @@ type MockQueryNodeServer_GetMetrics_Call struct {
} }
// GetMetrics is a helper method to define mock.On call // GetMetrics is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *milvuspb.GetMetricsRequest // - _a1 *milvuspb.GetMetricsRequest
func (_e *MockQueryNodeServer_Expecter) GetMetrics(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetMetrics_Call { func (_e *MockQueryNodeServer_Expecter) GetMetrics(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetMetrics_Call {
return &MockQueryNodeServer_GetMetrics_Call{Call: _e.mock.On("GetMetrics", _a0, _a1)} return &MockQueryNodeServer_GetMetrics_Call{Call: _e.mock.On("GetMetrics", _a0, _a1)}
} }
@ -199,8 +199,8 @@ type MockQueryNodeServer_GetSegmentInfo_Call struct {
} }
// GetSegmentInfo is a helper method to define mock.On call // GetSegmentInfo is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.GetSegmentInfoRequest // - _a1 *querypb.GetSegmentInfoRequest
func (_e *MockQueryNodeServer_Expecter) GetSegmentInfo(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetSegmentInfo_Call { func (_e *MockQueryNodeServer_Expecter) GetSegmentInfo(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetSegmentInfo_Call {
return &MockQueryNodeServer_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", _a0, _a1)} return &MockQueryNodeServer_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", _a0, _a1)}
} }
@ -246,8 +246,8 @@ type MockQueryNodeServer_GetStatistics_Call struct {
} }
// GetStatistics is a helper method to define mock.On call // GetStatistics is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.GetStatisticsRequest // - _a1 *querypb.GetStatisticsRequest
func (_e *MockQueryNodeServer_Expecter) GetStatistics(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetStatistics_Call { func (_e *MockQueryNodeServer_Expecter) GetStatistics(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetStatistics_Call {
return &MockQueryNodeServer_GetStatistics_Call{Call: _e.mock.On("GetStatistics", _a0, _a1)} return &MockQueryNodeServer_GetStatistics_Call{Call: _e.mock.On("GetStatistics", _a0, _a1)}
} }
@ -293,8 +293,8 @@ type MockQueryNodeServer_GetStatisticsChannel_Call struct {
} }
// GetStatisticsChannel is a helper method to define mock.On call // GetStatisticsChannel is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *internalpb.GetStatisticsChannelRequest // - _a1 *internalpb.GetStatisticsChannelRequest
func (_e *MockQueryNodeServer_Expecter) GetStatisticsChannel(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetStatisticsChannel_Call { func (_e *MockQueryNodeServer_Expecter) GetStatisticsChannel(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetStatisticsChannel_Call {
return &MockQueryNodeServer_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", _a0, _a1)} return &MockQueryNodeServer_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", _a0, _a1)}
} }
@ -340,8 +340,8 @@ type MockQueryNodeServer_GetTimeTickChannel_Call struct {
} }
// GetTimeTickChannel is a helper method to define mock.On call // GetTimeTickChannel is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *internalpb.GetTimeTickChannelRequest // - _a1 *internalpb.GetTimeTickChannelRequest
func (_e *MockQueryNodeServer_Expecter) GetTimeTickChannel(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetTimeTickChannel_Call { func (_e *MockQueryNodeServer_Expecter) GetTimeTickChannel(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_GetTimeTickChannel_Call {
return &MockQueryNodeServer_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", _a0, _a1)} return &MockQueryNodeServer_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", _a0, _a1)}
} }
@ -358,6 +358,53 @@ func (_c *MockQueryNodeServer_GetTimeTickChannel_Call) Return(_a0 *milvuspb.Stri
return _c return _c
} }
// LoadPartitions provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) LoadPartitions(_a0 context.Context, _a1 *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1)
var r0 *commonpb.Status
if rf, ok := ret.Get(0).(func(context.Context, *querypb.LoadPartitionsRequest) *commonpb.Status); ok {
r0 = rf(_a0, _a1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.LoadPartitionsRequest) error); ok {
r1 = rf(_a0, _a1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryNodeServer_LoadPartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadPartitions'
type MockQueryNodeServer_LoadPartitions_Call struct {
*mock.Call
}
// LoadPartitions is a helper method to define mock.On call
// - _a0 context.Context
// - _a1 *querypb.LoadPartitionsRequest
func (_e *MockQueryNodeServer_Expecter) LoadPartitions(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_LoadPartitions_Call {
return &MockQueryNodeServer_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", _a0, _a1)}
}
func (_c *MockQueryNodeServer_LoadPartitions_Call) Run(run func(_a0 context.Context, _a1 *querypb.LoadPartitionsRequest)) *MockQueryNodeServer_LoadPartitions_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.LoadPartitionsRequest))
})
return _c
}
func (_c *MockQueryNodeServer_LoadPartitions_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryNodeServer_LoadPartitions_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// LoadSegments provides a mock function with given fields: _a0, _a1 // LoadSegments provides a mock function with given fields: _a0, _a1
func (_m *MockQueryNodeServer) LoadSegments(_a0 context.Context, _a1 *querypb.LoadSegmentsRequest) (*commonpb.Status, error) { func (_m *MockQueryNodeServer) LoadSegments(_a0 context.Context, _a1 *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
ret := _m.Called(_a0, _a1) ret := _m.Called(_a0, _a1)
@ -387,8 +434,8 @@ type MockQueryNodeServer_LoadSegments_Call struct {
} }
// LoadSegments is a helper method to define mock.On call // LoadSegments is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.LoadSegmentsRequest // - _a1 *querypb.LoadSegmentsRequest
func (_e *MockQueryNodeServer_Expecter) LoadSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_LoadSegments_Call { func (_e *MockQueryNodeServer_Expecter) LoadSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_LoadSegments_Call {
return &MockQueryNodeServer_LoadSegments_Call{Call: _e.mock.On("LoadSegments", _a0, _a1)} return &MockQueryNodeServer_LoadSegments_Call{Call: _e.mock.On("LoadSegments", _a0, _a1)}
} }
@ -434,8 +481,8 @@ type MockQueryNodeServer_Query_Call struct {
} }
// Query is a helper method to define mock.On call // Query is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.QueryRequest // - _a1 *querypb.QueryRequest
func (_e *MockQueryNodeServer_Expecter) Query(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_Query_Call { func (_e *MockQueryNodeServer_Expecter) Query(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_Query_Call {
return &MockQueryNodeServer_Query_Call{Call: _e.mock.On("Query", _a0, _a1)} return &MockQueryNodeServer_Query_Call{Call: _e.mock.On("Query", _a0, _a1)}
} }
@ -481,8 +528,8 @@ type MockQueryNodeServer_ReleaseCollection_Call struct {
} }
// ReleaseCollection is a helper method to define mock.On call // ReleaseCollection is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.ReleaseCollectionRequest // - _a1 *querypb.ReleaseCollectionRequest
func (_e *MockQueryNodeServer_Expecter) ReleaseCollection(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ReleaseCollection_Call { func (_e *MockQueryNodeServer_Expecter) ReleaseCollection(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ReleaseCollection_Call {
return &MockQueryNodeServer_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", _a0, _a1)} return &MockQueryNodeServer_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", _a0, _a1)}
} }
@ -528,8 +575,8 @@ type MockQueryNodeServer_ReleasePartitions_Call struct {
} }
// ReleasePartitions is a helper method to define mock.On call // ReleasePartitions is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.ReleasePartitionsRequest // - _a1 *querypb.ReleasePartitionsRequest
func (_e *MockQueryNodeServer_Expecter) ReleasePartitions(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ReleasePartitions_Call { func (_e *MockQueryNodeServer_Expecter) ReleasePartitions(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ReleasePartitions_Call {
return &MockQueryNodeServer_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", _a0, _a1)} return &MockQueryNodeServer_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", _a0, _a1)}
} }
@ -575,8 +622,8 @@ type MockQueryNodeServer_ReleaseSegments_Call struct {
} }
// ReleaseSegments is a helper method to define mock.On call // ReleaseSegments is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.ReleaseSegmentsRequest // - _a1 *querypb.ReleaseSegmentsRequest
func (_e *MockQueryNodeServer_Expecter) ReleaseSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ReleaseSegments_Call { func (_e *MockQueryNodeServer_Expecter) ReleaseSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ReleaseSegments_Call {
return &MockQueryNodeServer_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", _a0, _a1)} return &MockQueryNodeServer_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", _a0, _a1)}
} }
@ -622,8 +669,8 @@ type MockQueryNodeServer_Search_Call struct {
} }
// Search is a helper method to define mock.On call // Search is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.SearchRequest // - _a1 *querypb.SearchRequest
func (_e *MockQueryNodeServer_Expecter) Search(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_Search_Call { func (_e *MockQueryNodeServer_Expecter) Search(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_Search_Call {
return &MockQueryNodeServer_Search_Call{Call: _e.mock.On("Search", _a0, _a1)} return &MockQueryNodeServer_Search_Call{Call: _e.mock.On("Search", _a0, _a1)}
} }
@ -669,8 +716,8 @@ type MockQueryNodeServer_ShowConfigurations_Call struct {
} }
// ShowConfigurations is a helper method to define mock.On call // ShowConfigurations is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *internalpb.ShowConfigurationsRequest // - _a1 *internalpb.ShowConfigurationsRequest
func (_e *MockQueryNodeServer_Expecter) ShowConfigurations(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ShowConfigurations_Call { func (_e *MockQueryNodeServer_Expecter) ShowConfigurations(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_ShowConfigurations_Call {
return &MockQueryNodeServer_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", _a0, _a1)} return &MockQueryNodeServer_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", _a0, _a1)}
} }
@ -716,8 +763,8 @@ type MockQueryNodeServer_SyncDistribution_Call struct {
} }
// SyncDistribution is a helper method to define mock.On call // SyncDistribution is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.SyncDistributionRequest // - _a1 *querypb.SyncDistributionRequest
func (_e *MockQueryNodeServer_Expecter) SyncDistribution(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_SyncDistribution_Call { func (_e *MockQueryNodeServer_Expecter) SyncDistribution(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_SyncDistribution_Call {
return &MockQueryNodeServer_SyncDistribution_Call{Call: _e.mock.On("SyncDistribution", _a0, _a1)} return &MockQueryNodeServer_SyncDistribution_Call{Call: _e.mock.On("SyncDistribution", _a0, _a1)}
} }
@ -763,8 +810,8 @@ type MockQueryNodeServer_SyncReplicaSegments_Call struct {
} }
// SyncReplicaSegments is a helper method to define mock.On call // SyncReplicaSegments is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.SyncReplicaSegmentsRequest // - _a1 *querypb.SyncReplicaSegmentsRequest
func (_e *MockQueryNodeServer_Expecter) SyncReplicaSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_SyncReplicaSegments_Call { func (_e *MockQueryNodeServer_Expecter) SyncReplicaSegments(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_SyncReplicaSegments_Call {
return &MockQueryNodeServer_SyncReplicaSegments_Call{Call: _e.mock.On("SyncReplicaSegments", _a0, _a1)} return &MockQueryNodeServer_SyncReplicaSegments_Call{Call: _e.mock.On("SyncReplicaSegments", _a0, _a1)}
} }
@ -810,8 +857,8 @@ type MockQueryNodeServer_UnsubDmChannel_Call struct {
} }
// UnsubDmChannel is a helper method to define mock.On call // UnsubDmChannel is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.UnsubDmChannelRequest // - _a1 *querypb.UnsubDmChannelRequest
func (_e *MockQueryNodeServer_Expecter) UnsubDmChannel(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_UnsubDmChannel_Call { func (_e *MockQueryNodeServer_Expecter) UnsubDmChannel(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_UnsubDmChannel_Call {
return &MockQueryNodeServer_UnsubDmChannel_Call{Call: _e.mock.On("UnsubDmChannel", _a0, _a1)} return &MockQueryNodeServer_UnsubDmChannel_Call{Call: _e.mock.On("UnsubDmChannel", _a0, _a1)}
} }
@ -857,8 +904,8 @@ type MockQueryNodeServer_WatchDmChannels_Call struct {
} }
// WatchDmChannels is a helper method to define mock.On call // WatchDmChannels is a helper method to define mock.On call
// - _a0 context.Context // - _a0 context.Context
// - _a1 *querypb.WatchDmChannelsRequest // - _a1 *querypb.WatchDmChannelsRequest
func (_e *MockQueryNodeServer_Expecter) WatchDmChannels(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_WatchDmChannels_Call { func (_e *MockQueryNodeServer_Expecter) WatchDmChannels(_a0 interface{}, _a1 interface{}) *MockQueryNodeServer_WatchDmChannels_Call {
return &MockQueryNodeServer_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", _a0, _a1)} return &MockQueryNodeServer_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", _a0, _a1)}
} }

View File

@ -33,12 +33,11 @@ import (
type CollectionObserver struct { type CollectionObserver struct {
stopCh chan struct{} stopCh chan struct{}
dist *meta.DistributionManager dist *meta.DistributionManager
meta *meta.Meta meta *meta.Meta
targetMgr *meta.TargetManager targetMgr *meta.TargetManager
targetObserver *TargetObserver targetObserver *TargetObserver
collectionLoadedCount map[int64]int partitionLoadedCount map[int64]int
partitionLoadedCount map[int64]int
stopOnce sync.Once stopOnce sync.Once
} }
@ -50,13 +49,12 @@ func NewCollectionObserver(
targetObserver *TargetObserver, targetObserver *TargetObserver,
) *CollectionObserver { ) *CollectionObserver {
return &CollectionObserver{ return &CollectionObserver{
stopCh: make(chan struct{}), stopCh: make(chan struct{}),
dist: dist, dist: dist,
meta: meta, meta: meta,
targetMgr: targetMgr, targetMgr: targetMgr,
targetObserver: targetObserver, targetObserver: targetObserver,
collectionLoadedCount: make(map[int64]int), partitionLoadedCount: make(map[int64]int),
partitionLoadedCount: make(map[int64]int),
} }
} }
@ -115,36 +113,24 @@ func (ob *CollectionObserver) observeTimeout() {
log.Info("observes partitions timeout", zap.Int("partitionNum", len(partitions))) log.Info("observes partitions timeout", zap.Int("partitionNum", len(partitions)))
} }
for collection, partitions := range partitions { for collection, partitions := range partitions {
log := log.With(
zap.Int64("collectionID", collection),
)
for _, partition := range partitions { for _, partition := range partitions {
if partition.GetStatus() != querypb.LoadStatus_Loading || if partition.GetStatus() != querypb.LoadStatus_Loading ||
time.Now().Before(partition.UpdatedAt.Add(Params.QueryCoordCfg.LoadTimeoutSeconds.GetAsDuration(time.Second))) { time.Now().Before(partition.UpdatedAt.Add(Params.QueryCoordCfg.LoadTimeoutSeconds.GetAsDuration(time.Second))) {
continue continue
} }
log.Info("load partition timeout, cancel all partitions", log.Info("load partition timeout, cancel it",
zap.Int64("collectionID", collection),
zap.Int64("partitionID", partition.GetPartitionID()), zap.Int64("partitionID", partition.GetPartitionID()),
zap.Duration("loadTime", time.Since(partition.CreatedAt))) zap.Duration("loadTime", time.Since(partition.CreatedAt)))
// TODO(yah01): Now, releasing part of partitions is not allowed ob.meta.CollectionManager.RemovePartition(partition.GetPartitionID())
ob.meta.CollectionManager.RemoveCollection(partition.GetCollectionID()) ob.targetMgr.RemovePartition(partition.GetCollectionID(), partition.GetPartitionID())
ob.meta.ReplicaManager.RemoveCollection(partition.GetCollectionID())
ob.targetMgr.RemoveCollection(partition.GetCollectionID())
break break
} }
} }
} }
func (ob *CollectionObserver) observeLoadStatus() { func (ob *CollectionObserver) observeLoadStatus() {
collections := ob.meta.CollectionManager.GetAllCollections()
for _, collection := range collections {
if collection.LoadPercentage == 100 {
continue
}
ob.observeCollectionLoadStatus(collection)
}
partitions := ob.meta.CollectionManager.GetAllPartitions() partitions := ob.meta.CollectionManager.GetAllPartitions()
if len(partitions) > 0 { if len(partitions) > 0 {
log.Info("observe partitions status", zap.Int("partitionNum", len(partitions))) log.Info("observe partitions status", zap.Int("partitionNum", len(partitions)))
@ -153,61 +139,30 @@ func (ob *CollectionObserver) observeLoadStatus() {
if partition.LoadPercentage == 100 { if partition.LoadPercentage == 100 {
continue continue
} }
ob.observePartitionLoadStatus(partition) replicaNum := ob.meta.GetReplicaNumber(partition.GetCollectionID())
ob.observePartitionLoadStatus(partition, replicaNum)
}
collections := ob.meta.CollectionManager.GetAllCollections()
for _, collection := range collections {
if collection.LoadPercentage == 100 {
continue
}
ob.observeCollectionLoadStatus(collection)
} }
} }
func (ob *CollectionObserver) observeCollectionLoadStatus(collection *meta.Collection) { func (ob *CollectionObserver) observeCollectionLoadStatus(collection *meta.Collection) {
log := log.With(zap.Int64("collectionID", collection.GetCollectionID())) log := log.With(zap.Int64("collectionID", collection.GetCollectionID()))
segmentTargets := ob.targetMgr.GetHistoricalSegmentsByCollection(collection.GetCollectionID(), meta.NextTarget)
channelTargets := ob.targetMgr.GetDmChannelsByCollection(collection.GetCollectionID(), meta.NextTarget)
targetNum := len(segmentTargets) + len(channelTargets)
log.Info("collection targets",
zap.Int("segmentTargetNum", len(segmentTargets)),
zap.Int("channelTargetNum", len(channelTargets)),
zap.Int("totalTargetNum", targetNum),
zap.Int32("replicaNum", collection.GetReplicaNumber()),
)
updated := collection.Clone() updated := collection.Clone()
loadedCount := 0 percentage := ob.meta.CollectionManager.GetCurrentLoadPercentage(collection.GetCollectionID())
if targetNum == 0 { if percentage <= updated.LoadPercentage {
log.Info("No segment/channel in target need to be loaded!")
updated.LoadPercentage = 100
} else {
for _, channel := range channelTargets {
group := utils.GroupNodesByReplica(ob.meta.ReplicaManager,
collection.GetCollectionID(),
ob.dist.LeaderViewManager.GetChannelDist(channel.GetChannelName()))
loadedCount += len(group)
}
subChannelCount := loadedCount
for _, segment := range segmentTargets {
group := utils.GroupNodesByReplica(ob.meta.ReplicaManager,
collection.GetCollectionID(),
ob.dist.LeaderViewManager.GetSealedSegmentDist(segment.GetID()))
loadedCount += len(group)
}
if loadedCount > 0 {
log.Info("collection load progress",
zap.Int("subChannelCount", subChannelCount),
zap.Int("loadSegmentCount", loadedCount-subChannelCount),
)
}
updated.LoadPercentage = int32(loadedCount * 100 / (targetNum * int(collection.GetReplicaNumber())))
}
if loadedCount <= ob.collectionLoadedCount[collection.GetCollectionID()] &&
updated.LoadPercentage != 100 {
ob.collectionLoadedCount[collection.GetCollectionID()] = loadedCount
return return
} }
ob.collectionLoadedCount[collection.GetCollectionID()] = loadedCount updated.LoadPercentage = percentage
if updated.LoadPercentage == 100 && ob.targetObserver.Check(updated.GetCollectionID()) { if updated.LoadPercentage == 100 && ob.targetObserver.Check(updated.GetCollectionID()) {
delete(ob.collectionLoadedCount, collection.GetCollectionID())
updated.Status = querypb.LoadStatus_Loaded updated.Status = querypb.LoadStatus_Loaded
ob.meta.CollectionManager.UpdateCollection(updated) ob.meta.CollectionManager.UpdateCollection(updated)
@ -221,7 +176,7 @@ func (ob *CollectionObserver) observeCollectionLoadStatus(collection *meta.Colle
zap.Int32("collectionStatus", int32(updated.GetStatus()))) zap.Int32("collectionStatus", int32(updated.GetStatus())))
} }
func (ob *CollectionObserver) observePartitionLoadStatus(partition *meta.Partition) { func (ob *CollectionObserver) observePartitionLoadStatus(partition *meta.Partition, replicaNum int32) {
log := log.With( log := log.With(
zap.Int64("collectionID", partition.GetCollectionID()), zap.Int64("collectionID", partition.GetCollectionID()),
zap.Int64("partitionID", partition.GetPartitionID()), zap.Int64("partitionID", partition.GetPartitionID()),
@ -234,7 +189,7 @@ func (ob *CollectionObserver) observePartitionLoadStatus(partition *meta.Partiti
zap.Int("segmentTargetNum", len(segmentTargets)), zap.Int("segmentTargetNum", len(segmentTargets)),
zap.Int("channelTargetNum", len(channelTargets)), zap.Int("channelTargetNum", len(channelTargets)),
zap.Int("totalTargetNum", targetNum), zap.Int("totalTargetNum", targetNum),
zap.Int32("replicaNum", partition.GetReplicaNumber()), zap.Int32("replicaNum", replicaNum),
) )
loadedCount := 0 loadedCount := 0
@ -261,7 +216,7 @@ func (ob *CollectionObserver) observePartitionLoadStatus(partition *meta.Partiti
zap.Int("subChannelCount", subChannelCount), zap.Int("subChannelCount", subChannelCount),
zap.Int("loadSegmentCount", loadedCount-subChannelCount)) zap.Int("loadSegmentCount", loadedCount-subChannelCount))
} }
updated.LoadPercentage = int32(loadedCount * 100 / (targetNum * int(partition.GetReplicaNumber()))) updated.LoadPercentage = int32(loadedCount * 100 / (targetNum * int(replicaNum)))
} }
if loadedCount <= ob.partitionLoadedCount[partition.GetPartitionID()] && if loadedCount <= ob.partitionLoadedCount[partition.GetPartitionID()] &&

View File

@ -21,6 +21,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/samber/lo"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
clientv3 "go.etcd.io/etcd/client/v3" clientv3 "go.etcd.io/etcd/client/v3"
@ -200,7 +201,7 @@ func (suite *CollectionObserverSuite) SetupTest() {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe() suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe()
} }
suite.targetObserver.Start(context.Background()) suite.targetObserver.Start(context.Background())
suite.ob.Start(context.Background())
suite.loadAll() suite.loadAll()
} }
@ -212,22 +213,12 @@ func (suite *CollectionObserverSuite) TearDownTest() {
func (suite *CollectionObserverSuite) TestObserve() { func (suite *CollectionObserverSuite) TestObserve() {
const ( const (
timeout = 2 * time.Second timeout = 3 * time.Second
) )
// time before load // time before load
time := suite.meta.GetCollection(suite.collections[2]).UpdatedAt time := suite.meta.GetCollection(suite.collections[2]).UpdatedAt
// Not timeout // Not timeout
paramtable.Get().Save(Params.QueryCoordCfg.LoadTimeoutSeconds.Key, "2") paramtable.Get().Save(Params.QueryCoordCfg.LoadTimeoutSeconds.Key, "3")
segments := []*datapb.SegmentBinlogs{}
for _, segment := range suite.segments[100] {
segments = append(segments, &datapb.SegmentBinlogs{
SegmentID: segment.GetID(),
InsertChannel: segment.GetInsertChannel(),
})
}
suite.ob.Start(context.Background())
// Collection 100 loaded before timeout, // Collection 100 loaded before timeout,
// collection 101 timeout // collection 101 timeout
@ -282,9 +273,45 @@ func (suite *CollectionObserverSuite) TestObserve() {
}, timeout*2, timeout/10) }, timeout*2, timeout/10)
} }
func (suite *CollectionObserverSuite) TestObservePartition() {
const (
timeout = 3 * time.Second
)
paramtable.Get().Save(Params.QueryCoordCfg.LoadTimeoutSeconds.Key, "3")
// Partition 10 loaded
suite.dist.LeaderViewManager.Update(1, &meta.LeaderView{
ID: 1,
CollectionID: 100,
Channel: "100-dmc0",
Segments: map[int64]*querypb.SegmentDist{1: {NodeID: 1, Version: 0}},
})
suite.dist.LeaderViewManager.Update(2, &meta.LeaderView{
ID: 2,
CollectionID: 100,
Channel: "100-dmc1",
Segments: map[int64]*querypb.SegmentDist{2: {NodeID: 2, Version: 0}},
})
// Partition 11 timeout
suite.dist.LeaderViewManager.Update(1, &meta.LeaderView{
ID: 1,
CollectionID: 101,
Channel: "",
Segments: map[int64]*querypb.SegmentDist{},
})
suite.Eventually(func() bool {
return suite.isPartitionLoaded(suite.partitions[100][0])
}, timeout*2, timeout/10)
suite.Eventually(func() bool {
return suite.isPartitionTimeout(suite.collections[1], suite.partitions[101][0])
}, timeout*2, timeout/10)
}
func (suite *CollectionObserverSuite) isCollectionLoaded(collection int64) bool { func (suite *CollectionObserverSuite) isCollectionLoaded(collection int64) bool {
exist := suite.meta.Exist(collection) exist := suite.meta.Exist(collection)
percentage := suite.meta.GetLoadPercentage(collection) percentage := suite.meta.GetCurrentLoadPercentage(collection)
status := suite.meta.GetStatus(collection) status := suite.meta.GetStatus(collection)
replicas := suite.meta.ReplicaManager.GetByCollection(collection) replicas := suite.meta.ReplicaManager.GetByCollection(collection)
channels := suite.targetMgr.GetDmChannelsByCollection(collection, meta.CurrentTarget) channels := suite.targetMgr.GetDmChannelsByCollection(collection, meta.CurrentTarget)
@ -298,6 +325,25 @@ func (suite *CollectionObserverSuite) isCollectionLoaded(collection int64) bool
len(segments) == len(suite.segments[collection]) len(segments) == len(suite.segments[collection])
} }
func (suite *CollectionObserverSuite) isPartitionLoaded(partitionID int64) bool {
partition := suite.meta.GetPartition(partitionID)
if partition == nil {
return false
}
collection := partition.GetCollectionID()
percentage := suite.meta.GetPartitionLoadPercentage(partitionID)
status := partition.GetStatus()
channels := suite.targetMgr.GetDmChannelsByCollection(collection, meta.CurrentTarget)
segments := suite.targetMgr.GetHistoricalSegmentsByPartition(collection, partitionID, meta.CurrentTarget)
expectedSegments := lo.Filter(suite.segments[collection], func(seg *datapb.SegmentInfo, _ int) bool {
return seg.PartitionID == partitionID
})
return percentage == 100 &&
status == querypb.LoadStatus_Loaded &&
len(channels) == len(suite.channels[collection]) &&
len(segments) == len(expectedSegments)
}
func (suite *CollectionObserverSuite) isCollectionTimeout(collection int64) bool { func (suite *CollectionObserverSuite) isCollectionTimeout(collection int64) bool {
exist := suite.meta.Exist(collection) exist := suite.meta.Exist(collection)
replicas := suite.meta.ReplicaManager.GetByCollection(collection) replicas := suite.meta.ReplicaManager.GetByCollection(collection)
@ -309,9 +355,14 @@ func (suite *CollectionObserverSuite) isCollectionTimeout(collection int64) bool
len(segments) > 0) len(segments) > 0)
} }
func (suite *CollectionObserverSuite) isPartitionTimeout(collection int64, partitionID int64) bool {
partition := suite.meta.GetPartition(partitionID)
segments := suite.targetMgr.GetHistoricalSegmentsByPartition(collection, partitionID, meta.CurrentTarget)
return partition == nil && len(segments) == 0
}
func (suite *CollectionObserverSuite) isCollectionLoadedContinue(collection int64, beforeTime time.Time) bool { func (suite *CollectionObserverSuite) isCollectionLoadedContinue(collection int64, beforeTime time.Time) bool {
return suite.meta.GetCollection(collection).UpdatedAt.After(beforeTime) return suite.meta.GetCollection(collection).UpdatedAt.After(beforeTime)
} }
func (suite *CollectionObserverSuite) loadAll() { func (suite *CollectionObserverSuite) loadAll() {
@ -332,32 +383,31 @@ func (suite *CollectionObserverSuite) load(collection int64) {
err = suite.meta.ReplicaManager.Put(replicas...) err = suite.meta.ReplicaManager.Put(replicas...)
suite.NoError(err) suite.NoError(err)
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection { suite.meta.PutCollection(&meta.Collection{
suite.meta.PutCollection(&meta.Collection{ CollectionLoadInfo: &querypb.CollectionLoadInfo{
CollectionLoadInfo: &querypb.CollectionLoadInfo{ CollectionID: collection,
ReplicaNumber: suite.replicaNumber[collection],
Status: querypb.LoadStatus_Loading,
LoadType: suite.loadTypes[collection],
},
LoadPercentage: 0,
CreatedAt: time.Now(),
})
for _, partition := range suite.partitions[collection] {
suite.meta.PutPartition(&meta.Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: collection, CollectionID: collection,
PartitionID: partition,
ReplicaNumber: suite.replicaNumber[collection], ReplicaNumber: suite.replicaNumber[collection],
Status: querypb.LoadStatus_Loading, Status: querypb.LoadStatus_Loading,
}, },
LoadPercentage: 0, LoadPercentage: 0,
CreatedAt: time.Now(), CreatedAt: time.Now(),
}) })
} else {
for _, partition := range suite.partitions[collection] {
suite.meta.PutPartition(&meta.Partition{
PartitionLoadInfo: &querypb.PartitionLoadInfo{
CollectionID: collection,
PartitionID: partition,
ReplicaNumber: suite.replicaNumber[collection],
Status: querypb.LoadStatus_Loading,
},
LoadPercentage: 0,
CreatedAt: time.Now(),
})
}
} }
allSegments := make([]*datapb.SegmentBinlogs, 0) allSegments := make(map[int64][]*datapb.SegmentBinlogs, 0) // partitionID -> segments
dmChannels := make([]*datapb.VchannelInfo, 0) dmChannels := make([]*datapb.VchannelInfo, 0)
for _, channel := range suite.channels[collection] { for _, channel := range suite.channels[collection] {
dmChannels = append(dmChannels, &datapb.VchannelInfo{ dmChannels = append(dmChannels, &datapb.VchannelInfo{
@ -367,16 +417,15 @@ func (suite *CollectionObserverSuite) load(collection int64) {
} }
for _, segment := range suite.segments[collection] { for _, segment := range suite.segments[collection] {
allSegments = append(allSegments, &datapb.SegmentBinlogs{ allSegments[segment.PartitionID] = append(allSegments[segment.PartitionID], &datapb.SegmentBinlogs{
SegmentID: segment.GetID(), SegmentID: segment.GetID(),
InsertChannel: segment.GetInsertChannel(), InsertChannel: segment.GetInsertChannel(),
}) })
} }
partitions := suite.partitions[collection] partitions := suite.partitions[collection]
for _, partition := range partitions { for _, partition := range partitions {
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collection, partition).Return(dmChannels, allSegments, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collection, partition).Return(dmChannels, allSegments[partition], nil)
} }
suite.targetMgr.UpdateCollectionNextTargetWithPartitions(collection, partitions...) suite.targetMgr.UpdateCollectionNextTargetWithPartitions(collection, partitions...)
} }

View File

@ -97,6 +97,7 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegments() {
ChannelName: "test-insert-channel", ChannelName: "test-insert-channel",
}, },
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
channels, segments, nil) channels, segments, nil)
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
@ -152,6 +153,7 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncLoadedSegments() {
ChannelName: "test-insert-channel", ChannelName: "test-insert-channel",
}, },
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
channels, segments, nil) channels, segments, nil)
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
@ -209,6 +211,8 @@ func (suite *LeaderObserverTestSuite) TestIgnoreBalancedSegment() {
ChannelName: "test-insert-channel", ChannelName: "test-insert-channel",
}, },
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
channels, segments, nil) channels, segments, nil)
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
@ -247,6 +251,7 @@ func (suite *LeaderObserverTestSuite) TestSyncLoadedSegmentsWithReplicas() {
ChannelName: "test-insert-channel", ChannelName: "test-insert-channel",
}, },
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
channels, segments, nil) channels, segments, nil)
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))
@ -340,6 +345,7 @@ func (suite *LeaderObserverTestSuite) TestIgnoreSyncRemovedSegments() {
ChannelName: "test-insert-channel", ChannelName: "test-insert-channel",
}, },
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, int64(1)).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return( suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, int64(1), int64(1)).Return(
channels, segments, nil) channels, segments, nil)
observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1)) observer.target.UpdateCollectionNextTargetWithPartitions(int64(1), int64(1))

View File

@ -37,6 +37,7 @@ type checkRequest struct {
type targetUpdateRequest struct { type targetUpdateRequest struct {
CollectionID int64 CollectionID int64
PartitionIDs []int64
Notifier chan error Notifier chan error
ReadyNotifier chan struct{} ReadyNotifier chan struct{}
} }
@ -108,7 +109,7 @@ func (ob *TargetObserver) schedule(ctx context.Context) {
req.Notifier <- ob.targetMgr.IsCurrentTargetExist(req.CollectionID) req.Notifier <- ob.targetMgr.IsCurrentTargetExist(req.CollectionID)
case req := <-ob.updateChan: case req := <-ob.updateChan:
err := ob.updateNextTarget(req.CollectionID) err := ob.updateNextTarget(req.CollectionID, req.PartitionIDs...)
if err != nil { if err != nil {
close(req.ReadyNotifier) close(req.ReadyNotifier)
} else { } else {
@ -148,13 +149,14 @@ func (ob *TargetObserver) check(collectionID int64) {
// UpdateNextTarget updates the next target, // UpdateNextTarget updates the next target,
// returns a channel which will be closed when the next target is ready, // returns a channel which will be closed when the next target is ready,
// or returns error if failed to pull target // or returns error if failed to pull target
func (ob *TargetObserver) UpdateNextTarget(collectionID int64) (chan struct{}, error) { func (ob *TargetObserver) UpdateNextTarget(collectionID int64, partitionIDs ...int64) (chan struct{}, error) {
notifier := make(chan error) notifier := make(chan error)
readyCh := make(chan struct{}) readyCh := make(chan struct{})
defer close(notifier) defer close(notifier)
ob.updateChan <- targetUpdateRequest{ ob.updateChan <- targetUpdateRequest{
CollectionID: collectionID, CollectionID: collectionID,
PartitionIDs: partitionIDs,
Notifier: notifier, Notifier: notifier,
ReadyNotifier: readyCh, ReadyNotifier: readyCh,
} }
@ -208,11 +210,16 @@ func (ob *TargetObserver) isNextTargetExpired(collectionID int64) bool {
return time.Since(ob.nextTargetLastUpdate[collectionID]) > params.Params.QueryCoordCfg.NextTargetSurviveTime.GetAsDuration(time.Second) return time.Since(ob.nextTargetLastUpdate[collectionID]) > params.Params.QueryCoordCfg.NextTargetSurviveTime.GetAsDuration(time.Second)
} }
func (ob *TargetObserver) updateNextTarget(collectionID int64) error { func (ob *TargetObserver) updateNextTarget(collectionID int64, partitionIDs ...int64) error {
log := log.With(zap.Int64("collectionID", collectionID)) log := log.With(zap.Int64("collectionID", collectionID), zap.Int64s("partIDs", partitionIDs))
log.Info("observer trigger update next target") log.Info("observer trigger update next target")
err := ob.targetMgr.UpdateCollectionNextTarget(collectionID) var err error
if len(partitionIDs) == 0 {
err = ob.targetMgr.UpdateCollectionNextTarget(collectionID)
} else {
err = ob.targetMgr.UpdateCollectionNextTargetWithPartitions(collectionID, partitionIDs...)
}
if err != nil { if err != nil {
log.Error("failed to update next target for collection", log.Error("failed to update next target for collection",
zap.Error(err)) zap.Error(err))

View File

@ -87,6 +87,8 @@ func (suite *TargetObserverSuite) SetupTest() {
err = suite.meta.CollectionManager.PutCollection(utils.CreateTestCollection(suite.collectionID, 1)) err = suite.meta.CollectionManager.PutCollection(utils.CreateTestCollection(suite.collectionID, 1))
suite.NoError(err) suite.NoError(err)
err = suite.meta.CollectionManager.PutPartition(utils.CreateTestPartition(suite.collectionID, suite.partitionID))
suite.NoError(err)
replicas, err := suite.meta.ReplicaManager.Spawn(suite.collectionID, 1, meta.DefaultResourceGroupName) replicas, err := suite.meta.ReplicaManager.Spawn(suite.collectionID, 1, meta.DefaultResourceGroupName)
suite.NoError(err) suite.NoError(err)
replicas[0].AddNode(2) replicas[0].AddNode(2)
@ -115,8 +117,8 @@ func (suite *TargetObserverSuite) SetupTest() {
}, },
} }
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, mock.Anything, mock.Anything).Return(suite.nextTargetChannels, suite.nextTargetSegments, nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, mock.Anything).Return([]int64{suite.partitionID}, nil) suite.broker.EXPECT().GetPartitions(mock.Anything, mock.Anything).Return([]int64{suite.partitionID}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, mock.Anything, mock.Anything).Return(suite.nextTargetChannels, suite.nextTargetSegments, nil)
} }
func (suite *TargetObserverSuite) TestTriggerUpdateTarget() { func (suite *TargetObserverSuite) TestTriggerUpdateTarget() {
@ -158,12 +160,10 @@ func (suite *TargetObserverSuite) TestTriggerUpdateTarget() {
suite.targetMgr.UpdateCollectionCurrentTarget(suite.collectionID) suite.targetMgr.UpdateCollectionCurrentTarget(suite.collectionID)
// Pull next again // Pull next again
suite.broker.EXPECT().GetPartitions(mock.Anything, mock.Anything).Return([]int64{suite.partitionID}, nil)
suite.broker.EXPECT(). suite.broker.EXPECT().
GetRecoveryInfo(mock.Anything, mock.Anything, mock.Anything). GetRecoveryInfo(mock.Anything, mock.Anything, mock.Anything).
Return(suite.nextTargetChannels, suite.nextTargetSegments, nil) Return(suite.nextTargetChannels, suite.nextTargetSegments, nil)
suite.broker.EXPECT().
GetPartitions(mock.Anything, mock.Anything).
Return([]int64{suite.partitionID}, nil)
suite.Eventually(func() bool { suite.Eventually(func() bool {
return len(suite.targetMgr.GetHistoricalSegmentsByCollection(suite.collectionID, meta.NextTarget)) == 3 && return len(suite.targetMgr.GetHistoricalSegmentsByCollection(suite.collectionID, meta.NextTarget)) == 3 &&
len(suite.targetMgr.GetDmChannelsByCollection(suite.collectionID, meta.NextTarget)) == 2 len(suite.targetMgr.GetDmChannelsByCollection(suite.collectionID, meta.NextTarget)) == 2

View File

@ -286,8 +286,13 @@ func (s *Server) initMeta() error {
s.store = meta.NewMetaStore(s.kv) s.store = meta.NewMetaStore(s.kv)
s.meta = meta.NewMeta(s.idAllocator, s.store, s.nodeMgr) s.meta = meta.NewMeta(s.idAllocator, s.store, s.nodeMgr)
s.broker = meta.NewCoordinatorBroker(
s.dataCoord,
s.rootCoord,
)
log.Info("recover meta...") log.Info("recover meta...")
err := s.meta.CollectionManager.Recover() err := s.meta.CollectionManager.Recover(s.broker)
if err != nil { if err != nil {
log.Error("failed to recover collections") log.Error("failed to recover collections")
return err return err
@ -295,6 +300,7 @@ func (s *Server) initMeta() error {
collections := s.meta.GetAll() collections := s.meta.GetAll()
log.Info("recovering collections...", zap.Int64s("collections", collections)) log.Info("recovering collections...", zap.Int64s("collections", collections))
metrics.QueryCoordNumCollections.WithLabelValues().Set(float64(len(collections))) metrics.QueryCoordNumCollections.WithLabelValues().Set(float64(len(collections)))
metrics.QueryCoordNumPartitions.WithLabelValues().Set(float64(len(s.meta.GetAllPartitions())))
err = s.meta.ReplicaManager.Recover(collections) err = s.meta.ReplicaManager.Recover(collections)
if err != nil { if err != nil {
@ -313,10 +319,6 @@ func (s *Server) initMeta() error {
ChannelDistManager: meta.NewChannelDistManager(), ChannelDistManager: meta.NewChannelDistManager(),
LeaderViewManager: meta.NewLeaderViewManager(), LeaderViewManager: meta.NewLeaderViewManager(),
} }
s.broker = meta.NewCoordinatorBroker(
s.dataCoord,
s.rootCoord,
)
s.targetMgr = meta.NewTargetManager(s.broker, s.meta) s.targetMgr = meta.NewTargetManager(s.broker, s.meta)
record.Record("Server initMeta") record.Record("Server initMeta")

View File

@ -23,7 +23,6 @@ import (
"time" "time"
"github.com/cockroachdb/errors" "github.com/cockroachdb/errors"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@ -42,6 +41,7 @@ import (
"github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/params"
"github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/session"
"github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/internal/querycoordv2/task"
"github.com/milvus-io/milvus/internal/querycoordv2/utils"
"github.com/milvus-io/milvus/internal/util/commonpbutil" "github.com/milvus-io/milvus/internal/util/commonpbutil"
"github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/merr" "github.com/milvus-io/milvus/internal/util/merr"
@ -116,6 +116,7 @@ func (suite *ServerSuite) SetupTest() {
ok := suite.waitNodeUp(suite.nodes[i], 5*time.Second) ok := suite.waitNodeUp(suite.nodes[i], 5*time.Second)
suite.Require().True(ok) suite.Require().True(ok)
suite.server.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, suite.nodes[i].ID) suite.server.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, suite.nodes[i].ID)
suite.expectLoadAndReleasePartitions(suite.nodes[i])
} }
suite.loadAll() suite.loadAll()
@ -158,14 +159,15 @@ func (suite *ServerSuite) TestRecoverFailed() {
suite.NoError(err) suite.NoError(err)
broker := meta.NewMockBroker(suite.T()) broker := meta.NewMockBroker(suite.T())
broker.EXPECT().GetPartitions(context.TODO(), int64(1000)).Return(nil, errors.New("CollectionNotExist")) for _, collection := range suite.collections {
broker.EXPECT().GetRecoveryInfo(context.TODO(), int64(1001), mock.Anything).Return(nil, nil, errors.New("CollectionNotExist")) broker.EXPECT().GetPartitions(mock.Anything, collection).Return([]int64{1}, nil)
broker.EXPECT().GetRecoveryInfo(context.TODO(), collection, mock.Anything).Return(nil, nil, errors.New("CollectionNotExist"))
}
suite.server.targetMgr = meta.NewTargetManager(broker, suite.server.meta) suite.server.targetMgr = meta.NewTargetManager(broker, suite.server.meta)
err = suite.server.Start() err = suite.server.Start()
suite.NoError(err) suite.NoError(err)
for _, collection := range suite.collections { for _, collection := range suite.collections {
suite.False(suite.server.meta.Exist(collection))
suite.Nil(suite.server.targetMgr.GetDmChannelsByCollection(collection, meta.NextTarget)) suite.Nil(suite.server.targetMgr.GetDmChannelsByCollection(collection, meta.NextTarget))
} }
} }
@ -259,20 +261,17 @@ func (suite *ServerSuite) TestEnableActiveStandby() {
Schema: &schemapb.CollectionSchema{}, Schema: &schemapb.CollectionSchema{},
}, nil).Maybe() }, nil).Maybe()
for _, collection := range suite.collections { for _, collection := range suite.collections {
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection { req := &milvuspb.ShowPartitionsRequest{
req := &milvuspb.ShowPartitionsRequest{ Base: commonpbutil.NewMsgBase(
Base: commonpbutil.NewMsgBase( commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions),
commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions), ),
), CollectionID: collection,
CollectionID: collection,
}
mockRootCoord.EXPECT().ShowPartitionsInternal(mock.Anything, req).Return(&milvuspb.ShowPartitionsResponse{
Status: merr.Status(nil),
PartitionIDs: suite.partitions[collection],
}, nil).Maybe()
} }
mockRootCoord.EXPECT().ShowPartitionsInternal(mock.Anything, req).Return(&milvuspb.ShowPartitionsResponse{
Status: merr.Status(nil),
PartitionIDs: suite.partitions[collection],
}, nil).Maybe()
suite.expectGetRecoverInfoByMockDataCoord(collection, mockDataCoord) suite.expectGetRecoverInfoByMockDataCoord(collection, mockDataCoord)
} }
err = suite.server.SetRootCoord(mockRootCoord) err = suite.server.SetRootCoord(mockRootCoord)
suite.NoError(err) suite.NoError(err)
@ -385,6 +384,11 @@ func (suite *ServerSuite) expectGetRecoverInfo(collection int64) {
} }
} }
func (suite *ServerSuite) expectLoadAndReleasePartitions(querynode *mocks.MockQueryNode) {
querynode.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil).Maybe()
querynode.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil).Maybe()
}
func (suite *ServerSuite) expectGetRecoverInfoByMockDataCoord(collection int64, dataCoord *coordMocks.DataCoord) { func (suite *ServerSuite) expectGetRecoverInfoByMockDataCoord(collection int64, dataCoord *coordMocks.DataCoord) {
var ( var (
vChannels []*datapb.VchannelInfo vChannels []*datapb.VchannelInfo
@ -432,7 +436,7 @@ func (suite *ServerSuite) updateCollectionStatus(collectionID int64, status quer
} }
collection.CollectionLoadInfo.Status = status collection.CollectionLoadInfo.Status = status
suite.server.meta.UpdateCollection(collection) suite.server.meta.UpdateCollection(collection)
} else {
partitions := suite.server.meta.GetPartitionsByCollection(collectionID) partitions := suite.server.meta.GetPartitionsByCollection(collectionID)
for _, partition := range partitions { for _, partition := range partitions {
partition := partition.Clone() partition := partition.Clone()
@ -488,9 +492,7 @@ func (suite *ServerSuite) hackServer() {
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(&schemapb.CollectionSchema{}, nil).Maybe() suite.broker.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything).Return(&schemapb.CollectionSchema{}, nil).Maybe()
for _, collection := range suite.collections { for _, collection := range suite.collections {
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection { suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe()
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil).Maybe()
}
suite.expectGetRecoverInfo(collection) suite.expectGetRecoverInfo(collection)
} }
log.Debug("server hacked") log.Debug("server hacked")

View File

@ -76,9 +76,6 @@ func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectio
for _, collection := range s.meta.GetAllCollections() { for _, collection := range s.meta.GetAllCollections() {
collectionSet.Insert(collection.GetCollectionID()) collectionSet.Insert(collection.GetCollectionID())
} }
for _, partition := range s.meta.GetAllPartitions() {
collectionSet.Insert(partition.GetCollectionID())
}
isGetAll = true isGetAll = true
} }
collections := collectionSet.Collect() collections := collectionSet.Collect()
@ -92,7 +89,7 @@ func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectio
for _, collectionID := range collections { for _, collectionID := range collections {
log := log.With(zap.Int64("collectionID", collectionID)) log := log.With(zap.Int64("collectionID", collectionID))
percentage := s.meta.CollectionManager.GetLoadPercentage(collectionID) percentage := s.meta.CollectionManager.GetCollectionLoadPercentage(collectionID)
if percentage < 0 { if percentage < 0 {
if isGetAll { if isGetAll {
// The collection is released during this, // The collection is released during this,
@ -139,67 +136,33 @@ func (s *Server) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions
} }
defer meta.GlobalFailedLoadCache.TryExpire() defer meta.GlobalFailedLoadCache.TryExpire()
// TODO(yah01): now, for load collection, the percentage of partition is equal to the percentage of collection,
// we can calculates the real percentage of partitions
partitions := req.GetPartitionIDs() partitions := req.GetPartitionIDs()
percentages := make([]int64, 0) percentages := make([]int64, 0)
isReleased := false
switch s.meta.GetLoadType(req.GetCollectionID()) {
case querypb.LoadType_LoadCollection:
percentage := s.meta.GetLoadPercentage(req.GetCollectionID())
if percentage < 0 {
isReleased = true
break
}
if len(partitions) == 0 { if len(partitions) == 0 {
var err error partitions = lo.Map(s.meta.GetPartitionsByCollection(req.GetCollectionID()), func(partition *meta.Partition, _ int) int64 {
partitions, err = s.broker.GetPartitions(ctx, req.GetCollectionID()) return partition.GetPartitionID()
})
}
for _, partitionID := range partitions {
percentage := s.meta.GetPartitionLoadPercentage(partitionID)
if percentage < 0 {
err := meta.GlobalFailedLoadCache.Get(req.GetCollectionID())
if err != nil { if err != nil {
msg := "failed to show partitions" status := merr.Status(err)
log.Warn(msg, zap.Error(err)) status.ErrorCode = commonpb.ErrorCode_InsufficientMemoryToLoad
log.Warn("show partition failed", zap.Error(err))
return &querypb.ShowPartitionsResponse{ return &querypb.ShowPartitionsResponse{
Status: utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg, err), Status: status,
}, nil }, nil
} }
} msg := fmt.Sprintf("partition %d has not been loaded to memory or load failed", partitionID)
for range partitions { log.Warn(msg)
percentages = append(percentages, int64(percentage))
}
case querypb.LoadType_LoadPartition:
if len(partitions) == 0 {
partitions = lo.Map(s.meta.GetPartitionsByCollection(req.GetCollectionID()), func(partition *meta.Partition, _ int) int64 {
return partition.GetPartitionID()
})
}
for _, partitionID := range partitions {
partition := s.meta.GetPartition(partitionID)
if partition == nil {
isReleased = true
break
}
percentages = append(percentages, int64(partition.LoadPercentage))
}
default:
isReleased = true
}
if isReleased {
err := meta.GlobalFailedLoadCache.Get(req.GetCollectionID())
if err != nil {
status := merr.Status(err)
status.ErrorCode = commonpb.ErrorCode_InsufficientMemoryToLoad
return &querypb.ShowPartitionsResponse{ return &querypb.ShowPartitionsResponse{
Status: status, Status: utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg),
}, nil }, nil
} }
msg := fmt.Sprintf("collection %v has not been loaded into QueryNode", req.GetCollectionID()) percentages = append(percentages, int64(percentage))
log.Warn(msg)
return &querypb.ShowPartitionsResponse{
Status: utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg),
}, nil
} }
return &querypb.ShowPartitionsResponse{ return &querypb.ShowPartitionsResponse{
@ -246,7 +209,9 @@ func (s *Server) LoadCollection(ctx context.Context, req *querypb.LoadCollection
req, req,
s.dist, s.dist,
s.meta, s.meta,
s.cluster,
s.targetMgr, s.targetMgr,
s.targetObserver,
s.broker, s.broker,
s.nodeMgr, s.nodeMgr,
) )
@ -340,7 +305,9 @@ func (s *Server) LoadPartitions(ctx context.Context, req *querypb.LoadPartitions
req, req,
s.dist, s.dist,
s.meta, s.meta,
s.cluster,
s.targetMgr, s.targetMgr,
s.targetObserver,
s.broker, s.broker,
s.nodeMgr, s.nodeMgr,
) )
@ -401,6 +368,7 @@ func (s *Server) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart
req, req,
s.dist, s.dist,
s.meta, s.meta,
s.cluster,
s.targetMgr, s.targetMgr,
s.targetObserver, s.targetObserver,
) )
@ -528,6 +496,31 @@ func (s *Server) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo
}, nil }, nil
} }
func (s *Server) SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error) {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.Int64("partitionID", req.GetPartitionID()),
)
log.Info("received sync new created partition request")
failedMsg := "failed to sync new created partition"
if s.status.Load() != commonpb.StateCode_Healthy {
log.Warn(failedMsg, zap.Error(ErrNotHealthy))
return utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, failedMsg, ErrNotHealthy), nil
}
syncJob := job.NewSyncNewCreatedPartitionJob(ctx, req, s.meta, s.cluster)
s.jobScheduler.Add(syncJob)
err := syncJob.Wait()
if err != nil && !errors.Is(err, job.ErrPartitionNotInTarget) {
log.Warn(failedMsg, zap.Error(err))
return utils.WrapStatus(errCode(err), failedMsg, err), nil
}
return merr.Status(nil), nil
}
// refreshCollection must be called after loading a collection. It looks for new segments that are not loaded yet and // refreshCollection must be called after loading a collection. It looks for new segments that are not loaded yet and
// tries to load them up. It returns when all segments of the given collection are loaded, or when error happens. // tries to load them up. It returns when all segments of the given collection are loaded, or when error happens.
// Note that a collection's loading progress always stays at 100% after a successful load and will not get updated // Note that a collection's loading progress always stays at 100% after a successful load and will not get updated
@ -547,7 +540,7 @@ func (s *Server) refreshCollection(ctx context.Context, collID int64) (*commonpb
} }
// Check that collection is fully loaded. // Check that collection is fully loaded.
if s.meta.CollectionManager.GetLoadPercentage(collID) != 100 { if s.meta.CollectionManager.GetCurrentLoadPercentage(collID) != 100 {
errMsg := "a collection must be fully loaded before refreshing" errMsg := "a collection must be fully loaded before refreshing"
log.Warn(errMsg) log.Warn(errMsg)
return &commonpb.Status{ return &commonpb.Status{
@ -601,7 +594,7 @@ func (s *Server) refreshPartitions(ctx context.Context, collID int64, partIDs []
} }
// Check that all partitions are fully loaded. // Check that all partitions are fully loaded.
if s.meta.CollectionManager.GetLoadPercentage(collID) != 100 { if s.meta.CollectionManager.GetCurrentLoadPercentage(collID) != 100 {
errMsg := "partitions must be fully loaded before refreshing" errMsg := "partitions must be fully loaded before refreshing"
log.Warn(errMsg) log.Warn(errMsg)
return &commonpb.Status{ return &commonpb.Status{
@ -671,7 +664,7 @@ func (s *Server) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques
log.Warn(msg, zap.Int("source-nodes-num", len(req.GetSourceNodeIDs()))) log.Warn(msg, zap.Int("source-nodes-num", len(req.GetSourceNodeIDs())))
return utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg), nil return utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg), nil
} }
if s.meta.CollectionManager.GetLoadPercentage(req.GetCollectionID()) < 100 { if s.meta.CollectionManager.GetCurrentLoadPercentage(req.GetCollectionID()) < 100 {
msg := "can't balance segments of not fully loaded collection" msg := "can't balance segments of not fully loaded collection"
log.Warn(msg) log.Warn(msg)
return utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg), nil return utils.WrapStatus(commonpb.ErrorCode_UnexpectedError, msg), nil
@ -845,7 +838,7 @@ func (s *Server) GetShardLeaders(ctx context.Context, req *querypb.GetShardLeade
Status: merr.Status(nil), Status: merr.Status(nil),
} }
if s.meta.CollectionManager.GetLoadPercentage(req.GetCollectionID()) < 100 { if s.meta.CollectionManager.GetCurrentLoadPercentage(req.GetCollectionID()) < 100 {
msg := fmt.Sprintf("collection %v is not fully loaded", req.GetCollectionID()) msg := fmt.Sprintf("collection %v is not fully loaded", req.GetCollectionID())
log.Warn(msg) log.Warn(msg)
resp.Status = utils.WrapStatus(commonpb.ErrorCode_NoReplicaAvailable, msg) resp.Status = utils.WrapStatus(commonpb.ErrorCode_NoReplicaAvailable, msg)

View File

@ -142,6 +142,7 @@ func (suite *ServiceSuite) SetupTest() {
suite.dist, suite.dist,
suite.broker, suite.broker,
) )
suite.targetObserver.Start(context.Background())
for _, node := range suite.nodes { for _, node := range suite.nodes {
suite.nodeMgr.Add(session.NewNodeInfo(node, "localhost")) suite.nodeMgr.Add(session.NewNodeInfo(node, "localhost"))
err := suite.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, node) err := suite.meta.ResourceManager.AssignNode(meta.DefaultResourceGroupName, node)
@ -311,7 +312,6 @@ func (suite *ServiceSuite) TestLoadCollection() {
// Test load all collections // Test load all collections
for _, collection := range suite.collections { for _, collection := range suite.collections {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil)
suite.expectGetRecoverInfo(collection) suite.expectGetRecoverInfo(collection)
req := &querypb.LoadCollectionRequest{ req := &querypb.LoadCollectionRequest{
@ -776,6 +776,10 @@ func (suite *ServiceSuite) TestLoadPartition() {
// Test load all partitions // Test load all partitions
for _, collection := range suite.collections { for _, collection := range suite.collections {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).
Return(append(suite.partitions[collection], 999), nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, collection, int64(999)).
Return(nil, nil, nil)
suite.expectGetRecoverInfo(collection) suite.expectGetRecoverInfo(collection)
req := &querypb.LoadPartitionsRequest{ req := &querypb.LoadPartitionsRequest{
@ -808,6 +812,36 @@ func (suite *ServiceSuite) TestLoadPartition() {
suite.NoError(err) suite.NoError(err)
suite.Equal(commonpb.ErrorCode_IllegalArgument, resp.ErrorCode) suite.Equal(commonpb.ErrorCode_IllegalArgument, resp.ErrorCode)
// Test load with collection loaded
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
continue
}
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
}
resp, err := server.LoadPartitions(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
}
// Test load with more partitions
suite.cluster.EXPECT().LoadPartitions(mock.Anything, mock.Anything, mock.Anything).
Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil)
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
continue
}
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: append(suite.partitions[collection], 999),
}
resp, err := server.LoadPartitions(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
}
// Test when server is not healthy // Test when server is not healthy
server.UpdateStateCode(commonpb.StateCode_Initializing) server.UpdateStateCode(commonpb.StateCode_Initializing)
req = &querypb.LoadPartitionsRequest{ req = &querypb.LoadPartitionsRequest{
@ -836,36 +870,6 @@ func (suite *ServiceSuite) TestLoadPartitionFailed() {
suite.Equal(commonpb.ErrorCode_IllegalArgument, resp.ErrorCode) suite.Equal(commonpb.ErrorCode_IllegalArgument, resp.ErrorCode)
suite.Contains(resp.Reason, job.ErrLoadParameterMismatched.Error()) suite.Contains(resp.Reason, job.ErrLoadParameterMismatched.Error())
} }
// Test load with collection loaded
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadCollection {
continue
}
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: suite.partitions[collection],
}
resp, err := server.LoadPartitions(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_IllegalArgument, resp.ErrorCode)
suite.Contains(resp.Reason, job.ErrLoadParameterMismatched.Error())
}
// Test load with more partitions
for _, collection := range suite.collections {
if suite.loadTypes[collection] != querypb.LoadType_LoadPartition {
continue
}
req := &querypb.LoadPartitionsRequest{
CollectionID: collection,
PartitionIDs: append(suite.partitions[collection], 999),
}
resp, err := server.LoadPartitions(ctx, req)
suite.NoError(err)
suite.Equal(commonpb.ErrorCode_IllegalArgument, resp.ErrorCode)
suite.Contains(resp.Reason, job.ErrLoadParameterMismatched.Error())
}
} }
func (suite *ServiceSuite) TestReleaseCollection() { func (suite *ServiceSuite) TestReleaseCollection() {
@ -910,6 +914,8 @@ func (suite *ServiceSuite) TestReleasePartition() {
server := suite.server server := suite.server
// Test release all partitions // Test release all partitions
suite.cluster.EXPECT().ReleasePartitions(mock.Anything, mock.Anything, mock.Anything).
Return(utils.WrapStatus(commonpb.ErrorCode_Success, ""), nil)
for _, collection := range suite.collections { for _, collection := range suite.collections {
req := &querypb.ReleasePartitionsRequest{ req := &querypb.ReleasePartitionsRequest{
CollectionID: collection, CollectionID: collection,
@ -917,11 +923,7 @@ func (suite *ServiceSuite) TestReleasePartition() {
} }
resp, err := server.ReleasePartitions(ctx, req) resp, err := server.ReleasePartitions(ctx, req)
suite.NoError(err) suite.NoError(err)
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection { suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.ErrorCode)
} else {
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
}
suite.assertPartitionLoaded(collection, suite.partitions[collection][1:]...) suite.assertPartitionLoaded(collection, suite.partitions[collection][1:]...)
} }
@ -933,11 +935,7 @@ func (suite *ServiceSuite) TestReleasePartition() {
} }
resp, err := server.ReleasePartitions(ctx, req) resp, err := server.ReleasePartitions(ctx, req)
suite.NoError(err) suite.NoError(err)
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection { suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
suite.Equal(commonpb.ErrorCode_UnexpectedError, resp.ErrorCode)
} else {
suite.Equal(commonpb.ErrorCode_Success, resp.ErrorCode)
}
suite.assertPartitionLoaded(collection, suite.partitions[collection][1:]...) suite.assertPartitionLoaded(collection, suite.partitions[collection][1:]...)
} }
@ -957,7 +955,6 @@ func (suite *ServiceSuite) TestRefreshCollection() {
defer cancel() defer cancel()
server := suite.server server := suite.server
suite.targetObserver.Start(context.Background())
suite.server.collectionObserver.Start(context.Background()) suite.server.collectionObserver.Start(context.Background())
// Test refresh all collections. // Test refresh all collections.
@ -970,7 +967,6 @@ func (suite *ServiceSuite) TestRefreshCollection() {
// Test load all collections // Test load all collections
for _, collection := range suite.collections { for _, collection := range suite.collections {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil)
suite.expectGetRecoverInfo(collection) suite.expectGetRecoverInfo(collection)
req := &querypb.LoadCollectionRequest{ req := &querypb.LoadCollectionRequest{
@ -1023,7 +1019,6 @@ func (suite *ServiceSuite) TestRefreshPartitions() {
defer cancel() defer cancel()
server := suite.server server := suite.server
suite.targetObserver.Start(context.Background())
suite.server.collectionObserver.Start(context.Background()) suite.server.collectionObserver.Start(context.Background())
// Test refresh all partitions. // Test refresh all partitions.
@ -1636,8 +1631,6 @@ func (suite *ServiceSuite) loadAll() {
for _, collection := range suite.collections { for _, collection := range suite.collections {
suite.expectGetRecoverInfo(collection) suite.expectGetRecoverInfo(collection)
if suite.loadTypes[collection] == querypb.LoadType_LoadCollection { if suite.loadTypes[collection] == querypb.LoadType_LoadCollection {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil)
req := &querypb.LoadCollectionRequest{ req := &querypb.LoadCollectionRequest{
CollectionID: collection, CollectionID: collection,
ReplicaNumber: suite.replicaNumber[collection], ReplicaNumber: suite.replicaNumber[collection],
@ -1647,7 +1640,9 @@ func (suite *ServiceSuite) loadAll() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -1669,7 +1664,9 @@ func (suite *ServiceSuite) loadAll() {
req, req,
suite.dist, suite.dist,
suite.meta, suite.meta,
suite.cluster,
suite.targetMgr, suite.targetMgr,
suite.targetObserver,
suite.broker, suite.broker,
suite.nodeMgr, suite.nodeMgr,
) )
@ -1741,6 +1738,7 @@ func (suite *ServiceSuite) assertSegments(collection int64, segments []*querypb.
} }
func (suite *ServiceSuite) expectGetRecoverInfo(collection int64) { func (suite *ServiceSuite) expectGetRecoverInfo(collection int64) {
suite.broker.EXPECT().GetPartitions(mock.Anything, collection).Return(suite.partitions[collection], nil)
vChannels := []*datapb.VchannelInfo{} vChannels := []*datapb.VchannelInfo{}
for _, channel := range suite.channels[collection] { for _, channel := range suite.channels[collection] {
vChannels = append(vChannels, &datapb.VchannelInfo{ vChannels = append(vChannels, &datapb.VchannelInfo{
@ -1848,7 +1846,7 @@ func (suite *ServiceSuite) updateCollectionStatus(collectionID int64, status que
} }
collection.CollectionLoadInfo.Status = status collection.CollectionLoadInfo.Status = status
suite.meta.UpdateCollection(collection) suite.meta.UpdateCollection(collection)
} else {
partitions := suite.meta.GetPartitionsByCollection(collectionID) partitions := suite.meta.GetPartitionsByCollection(collectionID)
for _, partition := range partitions { for _, partition := range partitions {
partition := partition.Clone() partition := partition.Clone()
@ -1869,6 +1867,10 @@ func (suite *ServiceSuite) fetchHeartbeats(time time.Time) {
} }
} }
func (suite *ServiceSuite) TearDownTest() {
suite.targetObserver.Stop()
}
func TestService(t *testing.T) { func TestService(t *testing.T) {
suite.Run(t, new(ServiceSuite)) suite.Run(t, new(ServiceSuite))
} }

View File

@ -53,6 +53,8 @@ type Cluster interface {
UnsubDmChannel(ctx context.Context, nodeID int64, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) UnsubDmChannel(ctx context.Context, nodeID int64, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error)
LoadSegments(ctx context.Context, nodeID int64, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) LoadSegments(ctx context.Context, nodeID int64, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error)
ReleaseSegments(ctx context.Context, nodeID int64, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) ReleaseSegments(ctx context.Context, nodeID int64, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error)
LoadPartitions(ctx context.Context, nodeID int64, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error)
ReleasePartitions(ctx context.Context, nodeID int64, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error)
GetDataDistribution(ctx context.Context, nodeID int64, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) GetDataDistribution(ctx context.Context, nodeID int64, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error)
GetMetrics(ctx context.Context, nodeID int64, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) GetMetrics(ctx context.Context, nodeID int64, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error)
SyncDistribution(ctx context.Context, nodeID int64, req *querypb.SyncDistributionRequest) (*commonpb.Status, error) SyncDistribution(ctx context.Context, nodeID int64, req *querypb.SyncDistributionRequest) (*commonpb.Status, error)
@ -174,6 +176,34 @@ func (c *QueryCluster) ReleaseSegments(ctx context.Context, nodeID int64, req *q
return status, err return status, err
} }
func (c *QueryCluster) LoadPartitions(ctx context.Context, nodeID int64, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
var status *commonpb.Status
var err error
err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
req := proto.Clone(req).(*querypb.LoadPartitionsRequest)
req.Base.TargetID = nodeID
status, err = cli.LoadPartitions(ctx, req)
})
if err1 != nil {
return nil, err1
}
return status, err
}
func (c *QueryCluster) ReleasePartitions(ctx context.Context, nodeID int64, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
var status *commonpb.Status
var err error
err1 := c.send(ctx, nodeID, func(cli types.QueryNode) {
req := proto.Clone(req).(*querypb.ReleasePartitionsRequest)
req.Base.TargetID = nodeID
status, err = cli.ReleasePartitions(ctx, req)
})
if err1 != nil {
return nil, err1
}
return status, err
}
func (c *QueryCluster) GetDataDistribution(ctx context.Context, nodeID int64, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) { func (c *QueryCluster) GetDataDistribution(ctx context.Context, nodeID int64, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) {
var resp *querypb.GetDataDistributionResponse var resp *querypb.GetDataDistributionResponse
var err error var err error

View File

@ -124,6 +124,14 @@ func (suite *ClusterTestSuite) createDefaultMockServer() querypb.QueryNodeServer
mock.Anything, mock.Anything,
mock.AnythingOfType("*querypb.ReleaseSegmentsRequest"), mock.AnythingOfType("*querypb.ReleaseSegmentsRequest"),
).Maybe().Return(succStatus, nil) ).Maybe().Return(succStatus, nil)
svr.EXPECT().LoadPartitions(
mock.Anything,
mock.AnythingOfType("*querypb.LoadPartitionsRequest"),
).Maybe().Return(succStatus, nil)
svr.EXPECT().ReleasePartitions(
mock.Anything,
mock.AnythingOfType("*querypb.ReleasePartitionsRequest"),
).Maybe().Return(succStatus, nil)
svr.EXPECT().GetDataDistribution( svr.EXPECT().GetDataDistribution(
mock.Anything, mock.Anything,
mock.AnythingOfType("*querypb.GetDataDistributionRequest"), mock.AnythingOfType("*querypb.GetDataDistributionRequest"),
@ -169,6 +177,14 @@ func (suite *ClusterTestSuite) createFailedMockServer() querypb.QueryNodeServer
mock.Anything, mock.Anything,
mock.AnythingOfType("*querypb.ReleaseSegmentsRequest"), mock.AnythingOfType("*querypb.ReleaseSegmentsRequest"),
).Maybe().Return(failStatus, nil) ).Maybe().Return(failStatus, nil)
svr.EXPECT().LoadPartitions(
mock.Anything,
mock.AnythingOfType("*querypb.LoadPartitionsRequest"),
).Maybe().Return(failStatus, nil)
svr.EXPECT().ReleasePartitions(
mock.Anything,
mock.AnythingOfType("*querypb.ReleasePartitionsRequest"),
).Maybe().Return(failStatus, nil)
svr.EXPECT().GetDataDistribution( svr.EXPECT().GetDataDistribution(
mock.Anything, mock.Anything,
mock.AnythingOfType("*querypb.GetDataDistributionRequest"), mock.AnythingOfType("*querypb.GetDataDistributionRequest"),
@ -284,6 +300,45 @@ func (suite *ClusterTestSuite) TestReleaseSegments() {
}, status) }, status)
} }
func (suite *ClusterTestSuite) TestLoadAndReleasePartitions() {
ctx := context.TODO()
status, err := suite.cluster.LoadPartitions(ctx, 0, &querypb.LoadPartitionsRequest{
Base: &commonpb.MsgBase{},
})
suite.NoError(err)
suite.Equal(&commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, status)
status, err = suite.cluster.LoadPartitions(ctx, 1, &querypb.LoadPartitionsRequest{
Base: &commonpb.MsgBase{},
})
suite.NoError(err)
suite.Equal(&commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unexpected error",
}, status)
status, err = suite.cluster.ReleasePartitions(ctx, 0, &querypb.ReleasePartitionsRequest{
Base: &commonpb.MsgBase{},
})
suite.NoError(err)
suite.Equal(&commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, status)
status, err = suite.cluster.ReleasePartitions(ctx, 1, &querypb.ReleasePartitionsRequest{
Base: &commonpb.MsgBase{},
})
suite.NoError(err)
suite.Equal(&commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "unexpected error",
}, status)
}
func (suite *ClusterTestSuite) TestGetDataDistribution() { func (suite *ClusterTestSuite) TestGetDataDistribution() {
ctx := context.TODO() ctx := context.TODO()
resp, err := suite.cluster.GetDataDistribution(ctx, 0, &querypb.GetDataDistributionRequest{ resp, err := suite.cluster.GetDataDistribution(ctx, 0, &querypb.GetDataDistributionRequest{

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.16.0. DO NOT EDIT. // Code generated by mockery v2.14.0. DO NOT EDIT.
package session package session
@ -56,8 +56,8 @@ type MockCluster_GetComponentStates_Call struct {
} }
// GetComponentStates is a helper method to define mock.On call // GetComponentStates is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
func (_e *MockCluster_Expecter) GetComponentStates(ctx interface{}, nodeID interface{}) *MockCluster_GetComponentStates_Call { func (_e *MockCluster_Expecter) GetComponentStates(ctx interface{}, nodeID interface{}) *MockCluster_GetComponentStates_Call {
return &MockCluster_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", ctx, nodeID)} return &MockCluster_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", ctx, nodeID)}
} }
@ -103,9 +103,9 @@ type MockCluster_GetDataDistribution_Call struct {
} }
// GetDataDistribution is a helper method to define mock.On call // GetDataDistribution is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.GetDataDistributionRequest // - req *querypb.GetDataDistributionRequest
func (_e *MockCluster_Expecter) GetDataDistribution(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_GetDataDistribution_Call { func (_e *MockCluster_Expecter) GetDataDistribution(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_GetDataDistribution_Call {
return &MockCluster_GetDataDistribution_Call{Call: _e.mock.On("GetDataDistribution", ctx, nodeID, req)} return &MockCluster_GetDataDistribution_Call{Call: _e.mock.On("GetDataDistribution", ctx, nodeID, req)}
} }
@ -151,9 +151,9 @@ type MockCluster_GetMetrics_Call struct {
} }
// GetMetrics is a helper method to define mock.On call // GetMetrics is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *milvuspb.GetMetricsRequest // - req *milvuspb.GetMetricsRequest
func (_e *MockCluster_Expecter) GetMetrics(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_GetMetrics_Call { func (_e *MockCluster_Expecter) GetMetrics(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_GetMetrics_Call {
return &MockCluster_GetMetrics_Call{Call: _e.mock.On("GetMetrics", ctx, nodeID, req)} return &MockCluster_GetMetrics_Call{Call: _e.mock.On("GetMetrics", ctx, nodeID, req)}
} }
@ -170,6 +170,54 @@ func (_c *MockCluster_GetMetrics_Call) Return(_a0 *milvuspb.GetMetricsResponse,
return _c return _c
} }
// LoadPartitions provides a mock function with given fields: ctx, nodeID, req
func (_m *MockCluster) LoadPartitions(ctx context.Context, nodeID int64, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
ret := _m.Called(ctx, nodeID, req)
var r0 *commonpb.Status
if rf, ok := ret.Get(0).(func(context.Context, int64, *querypb.LoadPartitionsRequest) *commonpb.Status); ok {
r0 = rf(ctx, nodeID, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int64, *querypb.LoadPartitionsRequest) error); ok {
r1 = rf(ctx, nodeID, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCluster_LoadPartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'LoadPartitions'
type MockCluster_LoadPartitions_Call struct {
*mock.Call
}
// LoadPartitions is a helper method to define mock.On call
// - ctx context.Context
// - nodeID int64
// - req *querypb.LoadPartitionsRequest
func (_e *MockCluster_Expecter) LoadPartitions(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_LoadPartitions_Call {
return &MockCluster_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", ctx, nodeID, req)}
}
func (_c *MockCluster_LoadPartitions_Call) Run(run func(ctx context.Context, nodeID int64, req *querypb.LoadPartitionsRequest)) *MockCluster_LoadPartitions_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64), args[2].(*querypb.LoadPartitionsRequest))
})
return _c
}
func (_c *MockCluster_LoadPartitions_Call) Return(_a0 *commonpb.Status, _a1 error) *MockCluster_LoadPartitions_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// LoadSegments provides a mock function with given fields: ctx, nodeID, req // LoadSegments provides a mock function with given fields: ctx, nodeID, req
func (_m *MockCluster) LoadSegments(ctx context.Context, nodeID int64, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) { func (_m *MockCluster) LoadSegments(ctx context.Context, nodeID int64, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
ret := _m.Called(ctx, nodeID, req) ret := _m.Called(ctx, nodeID, req)
@ -199,9 +247,9 @@ type MockCluster_LoadSegments_Call struct {
} }
// LoadSegments is a helper method to define mock.On call // LoadSegments is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.LoadSegmentsRequest // - req *querypb.LoadSegmentsRequest
func (_e *MockCluster_Expecter) LoadSegments(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_LoadSegments_Call { func (_e *MockCluster_Expecter) LoadSegments(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_LoadSegments_Call {
return &MockCluster_LoadSegments_Call{Call: _e.mock.On("LoadSegments", ctx, nodeID, req)} return &MockCluster_LoadSegments_Call{Call: _e.mock.On("LoadSegments", ctx, nodeID, req)}
} }
@ -218,6 +266,54 @@ func (_c *MockCluster_LoadSegments_Call) Return(_a0 *commonpb.Status, _a1 error)
return _c return _c
} }
// ReleasePartitions provides a mock function with given fields: ctx, nodeID, req
func (_m *MockCluster) ReleasePartitions(ctx context.Context, nodeID int64, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
ret := _m.Called(ctx, nodeID, req)
var r0 *commonpb.Status
if rf, ok := ret.Get(0).(func(context.Context, int64, *querypb.ReleasePartitionsRequest) *commonpb.Status); ok {
r0 = rf(ctx, nodeID, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int64, *querypb.ReleasePartitionsRequest) error); ok {
r1 = rf(ctx, nodeID, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCluster_ReleasePartitions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReleasePartitions'
type MockCluster_ReleasePartitions_Call struct {
*mock.Call
}
// ReleasePartitions is a helper method to define mock.On call
// - ctx context.Context
// - nodeID int64
// - req *querypb.ReleasePartitionsRequest
func (_e *MockCluster_Expecter) ReleasePartitions(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_ReleasePartitions_Call {
return &MockCluster_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", ctx, nodeID, req)}
}
func (_c *MockCluster_ReleasePartitions_Call) Run(run func(ctx context.Context, nodeID int64, req *querypb.ReleasePartitionsRequest)) *MockCluster_ReleasePartitions_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64), args[2].(*querypb.ReleasePartitionsRequest))
})
return _c
}
func (_c *MockCluster_ReleasePartitions_Call) Return(_a0 *commonpb.Status, _a1 error) *MockCluster_ReleasePartitions_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// ReleaseSegments provides a mock function with given fields: ctx, nodeID, req // ReleaseSegments provides a mock function with given fields: ctx, nodeID, req
func (_m *MockCluster) ReleaseSegments(ctx context.Context, nodeID int64, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) { func (_m *MockCluster) ReleaseSegments(ctx context.Context, nodeID int64, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
ret := _m.Called(ctx, nodeID, req) ret := _m.Called(ctx, nodeID, req)
@ -247,9 +343,9 @@ type MockCluster_ReleaseSegments_Call struct {
} }
// ReleaseSegments is a helper method to define mock.On call // ReleaseSegments is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.ReleaseSegmentsRequest // - req *querypb.ReleaseSegmentsRequest
func (_e *MockCluster_Expecter) ReleaseSegments(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_ReleaseSegments_Call { func (_e *MockCluster_Expecter) ReleaseSegments(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_ReleaseSegments_Call {
return &MockCluster_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", ctx, nodeID, req)} return &MockCluster_ReleaseSegments_Call{Call: _e.mock.On("ReleaseSegments", ctx, nodeID, req)}
} }
@ -277,7 +373,7 @@ type MockCluster_Start_Call struct {
} }
// Start is a helper method to define mock.On call // Start is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
func (_e *MockCluster_Expecter) Start(ctx interface{}) *MockCluster_Start_Call { func (_e *MockCluster_Expecter) Start(ctx interface{}) *MockCluster_Start_Call {
return &MockCluster_Start_Call{Call: _e.mock.On("Start", ctx)} return &MockCluster_Start_Call{Call: _e.mock.On("Start", ctx)}
} }
@ -350,9 +446,9 @@ type MockCluster_SyncDistribution_Call struct {
} }
// SyncDistribution is a helper method to define mock.On call // SyncDistribution is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.SyncDistributionRequest // - req *querypb.SyncDistributionRequest
func (_e *MockCluster_Expecter) SyncDistribution(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_SyncDistribution_Call { func (_e *MockCluster_Expecter) SyncDistribution(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_SyncDistribution_Call {
return &MockCluster_SyncDistribution_Call{Call: _e.mock.On("SyncDistribution", ctx, nodeID, req)} return &MockCluster_SyncDistribution_Call{Call: _e.mock.On("SyncDistribution", ctx, nodeID, req)}
} }
@ -398,9 +494,9 @@ type MockCluster_UnsubDmChannel_Call struct {
} }
// UnsubDmChannel is a helper method to define mock.On call // UnsubDmChannel is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.UnsubDmChannelRequest // - req *querypb.UnsubDmChannelRequest
func (_e *MockCluster_Expecter) UnsubDmChannel(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_UnsubDmChannel_Call { func (_e *MockCluster_Expecter) UnsubDmChannel(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_UnsubDmChannel_Call {
return &MockCluster_UnsubDmChannel_Call{Call: _e.mock.On("UnsubDmChannel", ctx, nodeID, req)} return &MockCluster_UnsubDmChannel_Call{Call: _e.mock.On("UnsubDmChannel", ctx, nodeID, req)}
} }
@ -446,9 +542,9 @@ type MockCluster_WatchDmChannels_Call struct {
} }
// WatchDmChannels is a helper method to define mock.On call // WatchDmChannels is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - nodeID int64 // - nodeID int64
// - req *querypb.WatchDmChannelsRequest // - req *querypb.WatchDmChannelsRequest
func (_e *MockCluster_Expecter) WatchDmChannels(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_WatchDmChannels_Call { func (_e *MockCluster_Expecter) WatchDmChannels(ctx interface{}, nodeID interface{}, req interface{}) *MockCluster_WatchDmChannels_Call {
return &MockCluster_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", ctx, nodeID, req)} return &MockCluster_WatchDmChannels_Call{Call: _e.mock.On("WatchDmChannels", ctx, nodeID, req)}
} }

View File

@ -247,7 +247,7 @@ func (ex *Executor) loadSegment(task *SegmentTask, step int) error {
log.Warn("failed to get schema of collection", zap.Error(err)) log.Warn("failed to get schema of collection", zap.Error(err))
return err return err
} }
partitions, err := utils.GetPartitions(ex.meta.CollectionManager, ex.broker, task.CollectionID()) partitions, err := utils.GetPartitions(ex.meta.CollectionManager, task.CollectionID())
if err != nil { if err != nil {
log.Warn("failed to get partitions of collection", zap.Error(err)) log.Warn("failed to get partitions of collection", zap.Error(err))
return err return err
@ -388,7 +388,7 @@ func (ex *Executor) subDmChannel(task *ChannelTask, step int) error {
log.Warn("failed to get schema of collection") log.Warn("failed to get schema of collection")
return err return err
} }
partitions, err := utils.GetPartitions(ex.meta.CollectionManager, ex.broker, task.CollectionID()) partitions, err := utils.GetPartitions(ex.meta.CollectionManager, task.CollectionID())
if err != nil { if err != nil {
log.Warn("failed to get partitions of collection") log.Warn("failed to get partitions of collection")
return err return err

View File

@ -185,8 +185,6 @@ func (suite *TaskSuite) TestSubscribeChannelTask() {
Return(&schemapb.CollectionSchema{ Return(&schemapb.CollectionSchema{
Name: "TestSubscribeChannelTask", Name: "TestSubscribeChannelTask",
}, nil) }, nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).
Return([]int64{100, 101}, nil)
channels := make([]*datapb.VchannelInfo, 0, len(suite.subChannels)) channels := make([]*datapb.VchannelInfo, 0, len(suite.subChannels))
for _, channel := range suite.subChannels { for _, channel := range suite.subChannels {
channels = append(channels, &datapb.VchannelInfo{ channels = append(channels, &datapb.VchannelInfo{
@ -234,6 +232,7 @@ func (suite *TaskSuite) TestSubscribeChannelTask() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(dmChannels, nil, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(dmChannels, nil, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
suite.AssertTaskNum(0, len(suite.subChannels), len(suite.subChannels), 0) suite.AssertTaskNum(0, len(suite.subChannels), len(suite.subChannels), 0)
@ -293,6 +292,7 @@ func (suite *TaskSuite) TestUnsubscribeChannelTask() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(dmChannels, nil, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(dmChannels, nil, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
@ -333,7 +333,6 @@ func (suite *TaskSuite) TestLoadSegmentTask() {
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{ suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{
Name: "TestLoadSegmentTask", Name: "TestLoadSegmentTask",
}, nil) }, nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{100, 101}, nil)
for _, segment := range suite.loadSegments { for _, segment := range suite.loadSegments {
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{ suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{
{ {
@ -374,6 +373,7 @@ func (suite *TaskSuite) TestLoadSegmentTask() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segments, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segments, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
segmentsNum := len(suite.loadSegments) segmentsNum := len(suite.loadSegments)
@ -417,7 +417,6 @@ func (suite *TaskSuite) TestLoadSegmentTaskFailed() {
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{ suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{
Name: "TestLoadSegmentTask", Name: "TestLoadSegmentTask",
}, nil) }, nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{100, 101}, nil)
for _, segment := range suite.loadSegments { for _, segment := range suite.loadSegments {
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{ suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{
{ {
@ -455,6 +454,7 @@ func (suite *TaskSuite) TestLoadSegmentTaskFailed() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segmentInfos, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segmentInfos, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
segmentsNum := len(suite.loadSegments) segmentsNum := len(suite.loadSegments)
@ -610,7 +610,6 @@ func (suite *TaskSuite) TestMoveSegmentTask() {
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{ suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{
Name: "TestMoveSegmentTask", Name: "TestMoveSegmentTask",
}, nil) }, nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{100, 101}, nil)
for _, segment := range suite.moveSegments { for _, segment := range suite.moveSegments {
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{ suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{
{ {
@ -665,6 +664,7 @@ func (suite *TaskSuite) TestMoveSegmentTask() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return([]*datapb.VchannelInfo{vchannel}, segmentInfos, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return([]*datapb.VchannelInfo{vchannel}, segmentInfos, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
suite.target.UpdateCollectionCurrentTarget(suite.collection, int64(1)) suite.target.UpdateCollectionCurrentTarget(suite.collection, int64(1))
@ -709,7 +709,6 @@ func (suite *TaskSuite) TestTaskCanceled() {
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{ suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{
Name: "TestSubscribeChannelTask", Name: "TestSubscribeChannelTask",
}, nil) }, nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{100, 101}, nil)
for _, segment := range suite.loadSegments { for _, segment := range suite.loadSegments {
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{ suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{
{ {
@ -752,6 +751,7 @@ func (suite *TaskSuite) TestTaskCanceled() {
} }
segmentsNum := len(suite.loadSegments) segmentsNum := len(suite.loadSegments)
suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum) suite.AssertTaskNum(0, segmentsNum, 0, segmentsNum)
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{partition}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, partition).Return(nil, segmentInfos, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, partition).Return(nil, segmentInfos, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, partition) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, partition)
@ -787,7 +787,6 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{ suite.broker.EXPECT().GetCollectionSchema(mock.Anything, suite.collection).Return(&schemapb.CollectionSchema{
Name: "TestSegmentTaskStale", Name: "TestSegmentTaskStale",
}, nil) }, nil)
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{100, 101}, nil)
for _, segment := range suite.loadSegments { for _, segment := range suite.loadSegments {
suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{ suite.broker.EXPECT().GetSegmentInfo(mock.Anything, segment).Return(&datapb.GetSegmentInfoResponse{Infos: []*datapb.SegmentInfo{
{ {
@ -829,6 +828,7 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segmentInfos, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segmentInfos, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
segmentsNum := len(suite.loadSegments) segmentsNum := len(suite.loadSegments)
@ -856,6 +856,10 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
InsertChannel: channel.GetChannelName(), InsertChannel: channel.GetChannelName(),
}) })
} }
bakExpectations := suite.broker.ExpectedCalls
suite.broker.AssertExpectations(suite.T())
suite.broker.ExpectedCalls = suite.broker.ExpectedCalls[:0]
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{2}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(2)).Return(nil, segmentInfos, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(2)).Return(nil, segmentInfos, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(2)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(2))
suite.dispatchAndWait(targetNode) suite.dispatchAndWait(targetNode)
@ -870,6 +874,7 @@ func (suite *TaskSuite) TestSegmentTaskStale() {
suite.NoError(task.Err()) suite.NoError(task.Err())
} }
} }
suite.broker.ExpectedCalls = bakExpectations
} }
func (suite *TaskSuite) TestChannelTaskReplace() { func (suite *TaskSuite) TestChannelTaskReplace() {
@ -1060,6 +1065,7 @@ func (suite *TaskSuite) TestNoExecutor() {
err = suite.scheduler.Add(task) err = suite.scheduler.Add(task)
suite.NoError(err) suite.NoError(err)
} }
suite.broker.EXPECT().GetPartitions(mock.Anything, suite.collection).Return([]int64{1}, nil)
suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segments, nil) suite.broker.EXPECT().GetRecoveryInfo(mock.Anything, suite.collection, int64(1)).Return(nil, segments, nil)
suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1)) suite.target.UpdateCollectionNextTargetWithPartitions(suite.collection, int64(1))
segmentsNum := len(suite.loadSegments) segmentsNum := len(suite.loadSegments)

View File

@ -17,7 +17,6 @@
package utils package utils
import ( import (
"context"
"fmt" "fmt"
"math/rand" "math/rand"
"sort" "sort"
@ -51,18 +50,15 @@ func GetReplicaNodesInfo(replicaMgr *meta.ReplicaManager, nodeMgr *session.NodeM
return nodes return nodes
} }
func GetPartitions(collectionMgr *meta.CollectionManager, broker meta.Broker, collectionID int64) ([]int64, error) { func GetPartitions(collectionMgr *meta.CollectionManager, collectionID int64) ([]int64, error) {
collection := collectionMgr.GetCollection(collectionID) collection := collectionMgr.GetCollection(collectionID)
if collection != nil { if collection != nil {
partitions, err := broker.GetPartitions(context.Background(), collectionID) partitions := collectionMgr.GetPartitionsByCollection(collectionID)
return partitions, err if partitions != nil {
} return lo.Map(partitions, func(partition *meta.Partition, i int) int64 {
return partition.PartitionID
partitions := collectionMgr.GetPartitionsByCollection(collectionID) }), nil
if partitions != nil { }
return lo.Map(partitions, func(partition *meta.Partition, i int) int64 {
return partition.PartitionID
}), nil
} }
// todo(yah01): replace this error with a defined error // todo(yah01): replace this error with a defined error

View File

@ -248,7 +248,7 @@ func TestFlowGraphDeleteNode_operate(t *testing.T) {
}, },
} }
msg := []flowgraph.Msg{&dMsg} msg := []flowgraph.Msg{&dMsg}
assert.Panics(t, func() { deleteNode.Operate(msg) }) deleteNode.Operate(msg)
}) })
t.Run("test partition not exist", func(t *testing.T) { t.Run("test partition not exist", func(t *testing.T) {

View File

@ -139,12 +139,12 @@ func (fddNode *filterDeleteNode) filterInvalidDeleteMessage(msg *msgstream.Delet
return nil, nil return nil, nil
} }
if loadType == loadTypePartition { //if loadType == loadTypePartition {
if !fddNode.metaReplica.hasPartition(msg.PartitionID) { // if !fddNode.metaReplica.hasPartition(msg.PartitionID) {
// filter out msg which not belongs to the loaded partitions // // filter out msg which not belongs to the loaded partitions
return nil, nil // return nil, nil
} // }
} //}
return msg, nil return msg, nil
} }

View File

@ -89,7 +89,7 @@ func TestFlowGraphFilterDeleteNode_filterInvalidDeleteMessage(t *testing.T) {
res, err := fg.filterInvalidDeleteMessage(msg, loadTypePartition) res, err := fg.filterInvalidDeleteMessage(msg, loadTypePartition)
assert.NoError(t, err) assert.NoError(t, err)
assert.Nil(t, res) assert.NotNil(t, res)
}) })
} }

View File

@ -162,12 +162,12 @@ func (fdmNode *filterDmNode) filterInvalidDeleteMessage(msg *msgstream.DeleteMsg
return nil, nil return nil, nil
} }
if loadType == loadTypePartition { //if loadType == loadTypePartition {
if !fdmNode.metaReplica.hasPartition(msg.PartitionID) { // if !fdmNode.metaReplica.hasPartition(msg.PartitionID) {
// filter out msg which not belongs to the loaded partitions // // filter out msg which not belongs to the loaded partitions
return nil, nil // return nil, nil
} // }
} //}
return msg, nil return msg, nil
} }
@ -198,12 +198,12 @@ func (fdmNode *filterDmNode) filterInvalidInsertMessage(msg *msgstream.InsertMsg
return nil, nil return nil, nil
} }
if loadType == loadTypePartition { //if loadType == loadTypePartition {
if !fdmNode.metaReplica.hasPartition(msg.PartitionID) { // if !fdmNode.metaReplica.hasPartition(msg.PartitionID) {
// filter out msg which not belongs to the loaded partitions // // filter out msg which not belongs to the loaded partitions
return nil, nil // return nil, nil
} // }
} //}
// Check if the segment is in excluded segments, // Check if the segment is in excluded segments,
// messages after seekPosition may contain the redundant data from flushed slice of segment, // messages after seekPosition may contain the redundant data from flushed slice of segment,

View File

@ -71,18 +71,6 @@ func TestFlowGraphFilterDmNode_filterInvalidInsertMessage(t *testing.T) {
fg.collectionID = defaultCollectionID fg.collectionID = defaultCollectionID
}) })
t.Run("test no partition", func(t *testing.T) {
msg, err := genSimpleInsertMsg(schema, defaultMsgLength)
assert.NoError(t, err)
msg.PartitionID = UniqueID(1000)
fg, err := getFilterDMNode()
assert.NoError(t, err)
res, err := fg.filterInvalidInsertMessage(msg, loadTypePartition)
assert.NoError(t, err)
assert.Nil(t, res)
})
t.Run("test not target collection", func(t *testing.T) { t.Run("test not target collection", func(t *testing.T) {
msg, err := genSimpleInsertMsg(schema, defaultMsgLength) msg, err := genSimpleInsertMsg(schema, defaultMsgLength)
assert.NoError(t, err) assert.NoError(t, err)
@ -162,17 +150,6 @@ func TestFlowGraphFilterDmNode_filterInvalidDeleteMessage(t *testing.T) {
assert.NotNil(t, res) assert.NotNil(t, res)
}) })
t.Run("test delete no partition", func(t *testing.T) {
msg := genDeleteMsg(defaultCollectionID, schemapb.DataType_Int64, defaultDelLength)
msg.PartitionID = UniqueID(1000)
fg, err := getFilterDMNode()
assert.NoError(t, err)
res, err := fg.filterInvalidDeleteMessage(msg, loadTypePartition)
assert.NoError(t, err)
assert.Nil(t, res)
})
t.Run("test delete not target collection", func(t *testing.T) { t.Run("test delete not target collection", func(t *testing.T) {
msg := genDeleteMsg(defaultCollectionID, schemapb.DataType_Int64, defaultDelLength) msg := genDeleteMsg(defaultCollectionID, schemapb.DataType_Int64, defaultDelLength)
fg, err := getFilterDMNode() fg, err := getFilterDMNode()

View File

@ -314,15 +314,6 @@ func processDeleteMessages(replica ReplicaInterface, segType segmentType, msg *m
var err error var err error
if msg.PartitionID != -1 { if msg.PartitionID != -1 {
partitionIDs = []UniqueID{msg.GetPartitionID()} partitionIDs = []UniqueID{msg.GetPartitionID()}
} else {
partitionIDs, err = replica.getPartitionIDs(msg.GetCollectionID())
if err != nil {
log.Warn("the collection has been released, ignore it",
zap.Int64("collectionID", msg.GetCollectionID()),
zap.Error(err),
)
return err
}
} }
var resultSegmentIDs []UniqueID var resultSegmentIDs []UniqueID
resultSegmentIDs, err = replica.getSegmentIDsByVChannel(partitionIDs, vchannelName, segType) resultSegmentIDs, err = replica.getSegmentIDsByVChannel(partitionIDs, vchannelName, segType)

View File

@ -288,7 +288,7 @@ func TestFlowGraphInsertNode_operate(t *testing.T) {
}, },
} }
msg := []flowgraph.Msg{&iMsg} msg := []flowgraph.Msg{&iMsg}
assert.Panics(t, func() { insertNode.Operate(msg) }) insertNode.Operate(msg)
}) })
t.Run("test partition not exist", func(t *testing.T) { t.Run("test partition not exist", func(t *testing.T) {

View File

@ -566,10 +566,10 @@ func (node *QueryNode) ReleaseCollection(ctx context.Context, in *querypb.Releas
return status, nil return status, nil
} }
// ReleasePartitions clears all data related to this partition on the querynode func (node *QueryNode) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { nodeID := node.session.ServerID
if !node.lifetime.Add(commonpbutil.IsHealthyOrStopping) { if !node.lifetime.Add(commonpbutil.IsHealthyOrStopping) {
err := fmt.Errorf("query node %d is not ready", node.GetSession().ServerID) err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
@ -578,35 +578,58 @@ func (node *QueryNode) ReleasePartitions(ctx context.Context, in *querypb.Releas
} }
defer node.lifetime.Done() defer node.lifetime.Done()
dct := &releasePartitionsTask{ // check target matches
baseTask: baseTask{ if req.GetBase().GetTargetID() != nodeID {
ctx: ctx, status := &commonpb.Status{
done: make(chan error), ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
}, Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), nodeID),
req: in, }
node: node, return status, nil
} }
err := node.scheduler.queue.Enqueue(dct) log.Ctx(ctx).With(zap.Int64("colID", req.GetCollectionID()), zap.Int64s("partIDs", req.GetPartitionIDs()))
if err != nil { log.Info("loading partitions")
for _, part := range req.GetPartitionIDs() {
err := node.metaReplica.addPartition(req.GetCollectionID(), part)
if err != nil {
log.Warn(err.Error())
}
}
log.Info("load partitions done")
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
}
return status, nil
}
// ReleasePartitions clears all data related to this partition on the querynode
func (node *QueryNode) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
nodeID := node.session.ServerID
if !node.lifetime.Add(commonpbutil.IsHealthyOrStopping) {
err := fmt.Errorf("query node %d is not ready", nodeID)
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError, ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(), Reason: err.Error(),
} }
log.Warn(err.Error())
return status, nil return status, nil
} }
log.Info("releasePartitionsTask Enqueue done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("partitionIDs", in.PartitionIDs)) defer node.lifetime.Done()
func() { // check target matches
err = dct.WaitToFinish() if req.GetBase().GetTargetID() != nodeID {
if err != nil { status := &commonpb.Status{
log.Warn(err.Error()) ErrorCode: commonpb.ErrorCode_NodeIDNotMatch,
return Reason: common.WrapNodeIDNotMatchMsg(req.GetBase().GetTargetID(), nodeID),
} }
log.Info("releasePartitionsTask WaitToFinish done", zap.Int64("collectionID", in.CollectionID), zap.Int64s("partitionIDs", in.PartitionIDs)) return status, nil
}() }
log.Ctx(ctx).With(zap.Int64("colID", req.GetCollectionID()), zap.Int64s("partIDs", req.GetPartitionIDs()))
log.Info("releasing partitions")
for _, part := range req.GetPartitionIDs() {
node.metaReplica.removePartition(part)
}
log.Info("release partitions done")
status := &commonpb.Status{ status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success, ErrorCode: commonpb.ErrorCode_Success,
} }

View File

@ -23,6 +23,10 @@ import (
"sync" "sync"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/common" "github.com/milvus-io/milvus/internal/common"
@ -33,10 +37,8 @@ import (
"github.com/milvus-io/milvus/internal/util/conc" "github.com/milvus-io/milvus/internal/util/conc"
"github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/metricsinfo" "github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
) )
func TestImpl_GetComponentStates(t *testing.T) { func TestImpl_GetComponentStates(t *testing.T) {
@ -360,6 +362,36 @@ func TestImpl_ReleaseCollection(t *testing.T) {
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
} }
func TestImpl_LoadPartitions(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
node, err := genSimpleQueryNode(ctx)
assert.NoError(t, err)
req := &queryPb.LoadPartitionsRequest{
Base: &commonpb.MsgBase{
TargetID: paramtable.GetNodeID(),
},
CollectionID: defaultCollectionID,
PartitionIDs: []UniqueID{defaultPartitionID},
}
status, err := node.LoadPartitions(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
node.UpdateStateCode(commonpb.StateCode_Abnormal)
status, err = node.LoadPartitions(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
node.UpdateStateCode(commonpb.StateCode_Healthy)
req.Base.TargetID = -1
status, err = node.LoadPartitions(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_NodeIDNotMatch, status.ErrorCode)
}
func TestImpl_ReleasePartitions(t *testing.T) { func TestImpl_ReleasePartitions(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
@ -368,8 +400,9 @@ func TestImpl_ReleasePartitions(t *testing.T) {
req := &queryPb.ReleasePartitionsRequest{ req := &queryPb.ReleasePartitionsRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_WatchQueryChannels, MsgType: commonpb.MsgType_WatchQueryChannels,
MsgID: rand.Int63(), MsgID: rand.Int63(),
TargetID: paramtable.GetNodeID(),
}, },
NodeID: 0, NodeID: 0,
CollectionID: defaultCollectionID, CollectionID: defaultCollectionID,
@ -384,6 +417,12 @@ func TestImpl_ReleasePartitions(t *testing.T) {
status, err = node.ReleasePartitions(ctx, req) status, err = node.ReleasePartitions(ctx, req)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode) assert.Equal(t, commonpb.ErrorCode_UnexpectedError, status.ErrorCode)
node.UpdateStateCode(commonpb.StateCode_Healthy)
req.Base.TargetID = -1
status, err = node.ReleasePartitions(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_NodeIDNotMatch, status.ErrorCode)
} }
func TestImpl_GetSegmentInfo(t *testing.T) { func TestImpl_GetSegmentInfo(t *testing.T) {

View File

@ -458,14 +458,14 @@ func (replica *metaReplica) removePartitionPrivate(partitionID UniqueID) error {
} }
// delete segments // delete segments
ids, _ := partition.getSegmentIDs(segmentTypeGrowing) //ids, _ := partition.getSegmentIDs(segmentTypeGrowing)
for _, segmentID := range ids { //for _, segmentID := range ids {
replica.removeSegmentPrivate(segmentID, segmentTypeGrowing) // replica.removeSegmentPrivate(segmentID, segmentTypeGrowing)
} //}
ids, _ = partition.getSegmentIDs(segmentTypeSealed) //ids, _ = partition.getSegmentIDs(segmentTypeSealed)
for _, segmentID := range ids { //for _, segmentID := range ids {
replica.removeSegmentPrivate(segmentID, segmentTypeSealed) // replica.removeSegmentPrivate(segmentID, segmentTypeSealed)
} //}
collection.removePartitionID(partitionID) collection.removePartitionID(partitionID)
delete(replica.partitions, partitionID) delete(replica.partitions, partitionID)
@ -589,10 +589,6 @@ func (replica *metaReplica) addSegment(segmentID UniqueID, partitionID UniqueID,
// addSegmentPrivate is private function in collectionReplica, to add a new segment to collectionReplica // addSegmentPrivate is private function in collectionReplica, to add a new segment to collectionReplica
func (replica *metaReplica) addSegmentPrivate(segment *Segment) error { func (replica *metaReplica) addSegmentPrivate(segment *Segment) error {
segID := segment.segmentID segID := segment.segmentID
partition, err := replica.getPartitionByIDPrivate(segment.partitionID)
if err != nil {
return err
}
segType := segment.getType() segType := segment.getType()
ok, err := replica.hasSegmentPrivate(segID, segType) ok, err := replica.hasSegmentPrivate(segID, segType)
@ -603,12 +599,16 @@ func (replica *metaReplica) addSegmentPrivate(segment *Segment) error {
return fmt.Errorf("segment has been existed, "+ return fmt.Errorf("segment has been existed, "+
"segmentID = %d, collectionID = %d, segmentType = %s", segID, segment.collectionID, segType.String()) "segmentID = %d, collectionID = %d, segmentType = %s", segID, segment.collectionID, segType.String())
} }
partition.addSegmentID(segID, segType)
switch segType { switch segType {
case segmentTypeGrowing: case segmentTypeGrowing:
replica.growingSegments[segID] = segment replica.growingSegments[segID] = segment
case segmentTypeSealed: case segmentTypeSealed:
partition, err := replica.getPartitionByIDPrivate(segment.partitionID)
if err != nil {
return err
}
partition.addSegmentID(segID, segType)
replica.sealedSegments[segID] = segment replica.sealedSegments[segID] = segment
default: default:
return fmt.Errorf("unexpected segment type, segmentID = %d, segmentType = %s", segID, segType.String()) return fmt.Errorf("unexpected segment type, segmentID = %d, segmentType = %s", segID, segType.String())

View File

@ -200,6 +200,7 @@ func TestStreaming_search(t *testing.T) {
collection, err := streaming.getCollectionByID(defaultCollectionID) collection, err := streaming.getCollectionByID(defaultCollectionID)
assert.NoError(t, err) assert.NoError(t, err)
collection.setLoadType(loadTypeCollection)
searchReq, err := genSearchPlanAndRequests(collection, IndexFaissIDMap, defaultNQ) searchReq, err := genSearchPlanAndRequests(collection, IndexFaissIDMap, defaultNQ)
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -76,6 +76,10 @@ func TestHistorical_statistic(t *testing.T) {
his, err := genSimpleReplicaWithSealSegment(ctx) his, err := genSimpleReplicaWithSealSegment(ctx)
assert.NoError(t, err) assert.NoError(t, err)
collection, err := his.getCollectionByID(defaultCollectionID)
assert.NoError(t, err)
collection.setLoadType(loadTypeCollection)
err = his.removePartition(defaultPartitionID) err = his.removePartition(defaultPartitionID)
assert.NoError(t, err) assert.NoError(t, err)
@ -153,6 +157,10 @@ func TestStreaming_statistics(t *testing.T) {
streaming, err := genSimpleReplicaWithGrowingSegment() streaming, err := genSimpleReplicaWithGrowingSegment()
assert.NoError(t, err) assert.NoError(t, err)
collection, err := streaming.getCollectionByID(defaultCollectionID)
assert.NoError(t, err)
collection.setLoadType(loadTypeCollection)
err = streaming.removePartition(defaultPartitionID) err = streaming.removePartition(defaultPartitionID)
assert.NoError(t, err) assert.NoError(t, err)

View File

@ -86,6 +86,9 @@ func TestQueryShardHistorical_validateSegmentIDs(t *testing.T) {
t.Run("test validate after partition release", func(t *testing.T) { t.Run("test validate after partition release", func(t *testing.T) {
his, err := genSimpleReplicaWithSealSegment(ctx) his, err := genSimpleReplicaWithSealSegment(ctx)
assert.NoError(t, err) assert.NoError(t, err)
collection, err := his.getCollectionByID(defaultCollectionID)
assert.NoError(t, err)
collection.setLoadType(loadTypeCollection)
err = his.removePartition(defaultPartitionID) err = his.removePartition(defaultPartitionID)
assert.NoError(t, err) assert.NoError(t, err)
_, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID}) _, _, err = validateOnHistoricalReplica(context.TODO(), his, defaultCollectionID, []UniqueID{}, []UniqueID{defaultSegmentID})

View File

@ -48,6 +48,8 @@ type watchInfo struct {
// Broker communicates with other components. // Broker communicates with other components.
type Broker interface { type Broker interface {
ReleaseCollection(ctx context.Context, collectionID UniqueID) error ReleaseCollection(ctx context.Context, collectionID UniqueID) error
ReleasePartitions(ctx context.Context, collectionID UniqueID, partitionIDs ...UniqueID) error
SyncNewCreatedPartition(ctx context.Context, collectionID UniqueID, partitionID UniqueID) error
GetQuerySegmentInfo(ctx context.Context, collectionID int64, segIDs []int64) (retResp *querypb.GetSegmentInfoResponse, retErr error) GetQuerySegmentInfo(ctx context.Context, collectionID int64, segIDs []int64) (retResp *querypb.GetSegmentInfoResponse, retErr error)
WatchChannels(ctx context.Context, info *watchInfo) error WatchChannels(ctx context.Context, info *watchInfo) error
@ -93,6 +95,49 @@ func (b *ServerBroker) ReleaseCollection(ctx context.Context, collectionID Uniqu
return nil return nil
} }
func (b *ServerBroker) ReleasePartitions(ctx context.Context, collectionID UniqueID, partitionIDs ...UniqueID) error {
if len(partitionIDs) == 0 {
return nil
}
log := log.Ctx(ctx).With(zap.Int64("collection", collectionID), zap.Int64s("partitionIDs", partitionIDs))
log.Info("releasing partitions")
resp, err := b.s.queryCoord.ReleasePartitions(ctx, &querypb.ReleasePartitionsRequest{
Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_ReleasePartitions)),
CollectionID: collectionID,
PartitionIDs: partitionIDs,
})
if err != nil {
return err
}
if resp.GetErrorCode() != commonpb.ErrorCode_Success {
return fmt.Errorf("release partition failed, reason: %s", resp.GetReason())
}
log.Info("release partitions done")
return nil
}
func (b *ServerBroker) SyncNewCreatedPartition(ctx context.Context, collectionID UniqueID, partitionID UniqueID) error {
log := log.Ctx(ctx).With(zap.Int64("collection", collectionID), zap.Int64("partitionID", partitionID))
log.Info("begin to sync new partition")
resp, err := b.s.queryCoord.SyncNewCreatedPartition(ctx, &querypb.SyncNewCreatedPartitionRequest{
Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_ReleasePartitions)),
CollectionID: collectionID,
PartitionID: partitionID,
})
if err != nil {
return err
}
if resp.GetErrorCode() != commonpb.ErrorCode_Success {
return fmt.Errorf("sync new partition failed, reason: %s", resp.GetReason())
}
log.Info("sync new partition done")
return nil
}
func (b *ServerBroker) GetQuerySegmentInfo(ctx context.Context, collectionID int64, segIDs []int64) (retResp *querypb.GetSegmentInfoResponse, retErr error) { func (b *ServerBroker) GetQuerySegmentInfo(ctx context.Context, collectionID int64, segIDs []int64) (retResp *querypb.GetSegmentInfoResponse, retErr error) {
resp, err := b.s.queryCoord.GetSegmentInfo(ctx, &querypb.GetSegmentInfoRequest{ resp, err := b.s.queryCoord.GetSegmentInfo(ctx, &querypb.GetSegmentInfoRequest{
Base: commonpbutil.NewMsgBase( Base: commonpbutil.NewMsgBase(

View File

@ -73,7 +73,7 @@ func (t *createPartitionTask) Execute(ctx context.Context) error {
PartitionCreatedTimestamp: t.GetTs(), PartitionCreatedTimestamp: t.GetTs(),
Extra: nil, Extra: nil,
CollectionID: t.collMeta.CollectionID, CollectionID: t.collMeta.CollectionID,
State: pb.PartitionState_PartitionCreated, State: pb.PartitionState_PartitionCreating,
} }
undoTask := newBaseUndoTask(t.core.stepExecutor) undoTask := newBaseUndoTask(t.core.stepExecutor)
@ -88,5 +88,23 @@ func (t *createPartitionTask) Execute(ctx context.Context) error {
partition: partition, partition: partition,
}, &nullStep{}) // adding partition is atomic enough. }, &nullStep{}) // adding partition is atomic enough.
undoTask.AddStep(&syncNewCreatedPartitionStep{
baseStep: baseStep{core: t.core},
collectionID: t.collMeta.CollectionID,
partitionID: partID,
}, &releasePartitionsStep{
baseStep: baseStep{core: t.core},
collectionID: t.collMeta.CollectionID,
partitionIDs: []int64{partID},
})
undoTask.AddStep(&changePartitionStateStep{
baseStep: baseStep{core: t.core},
collectionID: t.collMeta.CollectionID,
partitionID: partID,
state: pb.PartitionState_PartitionCreated,
ts: t.GetTs(),
}, &nullStep{})
return undoTask.Execute(ctx) return undoTask.Execute(ctx)
} }

View File

@ -20,14 +20,13 @@ import (
"context" "context"
"testing" "testing"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/proto/etcdpb"
"github.com/milvus-io/milvus/internal/util/funcutil"
) )
func Test_createPartitionTask_Prepare(t *testing.T) { func Test_createPartitionTask_Prepare(t *testing.T) {
@ -147,7 +146,14 @@ func Test_createPartitionTask_Execute(t *testing.T) {
meta.AddPartitionFunc = func(ctx context.Context, partition *model.Partition) error { meta.AddPartitionFunc = func(ctx context.Context, partition *model.Partition) error {
return nil return nil
} }
core := newTestCore(withValidIDAllocator(), withValidProxyManager(), withMeta(meta)) meta.ChangePartitionStateFunc = func(ctx context.Context, collectionID UniqueID, partitionID UniqueID, state etcdpb.PartitionState, ts Timestamp) error {
return nil
}
b := newMockBroker()
b.SyncNewCreatedPartitionFunc = func(ctx context.Context, collectionID UniqueID, partitionID UniqueID) error {
return nil
}
core := newTestCore(withValidIDAllocator(), withValidProxyManager(), withMeta(meta), withBroker(b))
task := &createPartitionTask{ task := &createPartitionTask{
baseTask: baseTask{core: core}, baseTask: baseTask{core: core},
collMeta: coll, collMeta: coll,

View File

@ -85,7 +85,12 @@ func (t *dropPartitionTask) Execute(ctx context.Context) error {
ts: t.GetTs(), ts: t.GetTs(),
}) })
// TODO: release partition when query coord is ready. redoTask.AddAsyncStep(&releasePartitionsStep{
baseStep: baseStep{core: t.core},
collectionID: t.collMeta.CollectionID,
partitionIDs: []int64{partID},
})
redoTask.AddAsyncStep(&deletePartitionDataStep{ redoTask.AddAsyncStep(&deletePartitionDataStep{
baseStep: baseStep{core: t.core}, baseStep: baseStep{core: t.core},
pchans: t.collMeta.PhysicalChannelNames, pchans: t.collMeta.PhysicalChannelNames,

View File

@ -177,6 +177,9 @@ func Test_dropPartitionTask_Execute(t *testing.T) {
broker.DropCollectionIndexFunc = func(ctx context.Context, collID UniqueID, partIDs []UniqueID) error { broker.DropCollectionIndexFunc = func(ctx context.Context, collID UniqueID, partIDs []UniqueID) error {
return nil return nil
} }
broker.ReleasePartitionsFunc = func(ctx context.Context, collectionID UniqueID, partitionIDs ...UniqueID) error {
return nil
}
core := newTestCore( core := newTestCore(
withValidProxyManager(), withValidProxyManager(),

View File

@ -512,7 +512,7 @@ func (mt *MetaTable) AddPartition(ctx context.Context, partition *model.Partitio
if !ok || !coll.Available() { if !ok || !coll.Available() {
return fmt.Errorf("collection not exists: %d", partition.CollectionID) return fmt.Errorf("collection not exists: %d", partition.CollectionID)
} }
if partition.State != pb.PartitionState_PartitionCreated { if partition.State != pb.PartitionState_PartitionCreating {
return fmt.Errorf("partition state is not created, collection: %d, partition: %d, state: %s", partition.CollectionID, partition.PartitionID, partition.State) return fmt.Errorf("partition state is not created, collection: %d, partition: %d, state: %s", partition.CollectionID, partition.PartitionID, partition.State)
} }
if err := mt.catalog.CreatePartition(ctx, partition, partition.PartitionCreatedTimestamp); err != nil { if err := mt.catalog.CreatePartition(ctx, partition, partition.PartitionCreatedTimestamp); err != nil {

View File

@ -892,7 +892,7 @@ func TestMetaTable_AddPartition(t *testing.T) {
100: {Name: "test", CollectionID: 100}, 100: {Name: "test", CollectionID: 100},
}, },
} }
err := meta.AddPartition(context.TODO(), &model.Partition{CollectionID: 100, State: pb.PartitionState_PartitionCreated}) err := meta.AddPartition(context.TODO(), &model.Partition{CollectionID: 100, State: pb.PartitionState_PartitionCreating})
assert.Error(t, err) assert.Error(t, err)
}) })
@ -909,7 +909,7 @@ func TestMetaTable_AddPartition(t *testing.T) {
100: {Name: "test", CollectionID: 100}, 100: {Name: "test", CollectionID: 100},
}, },
} }
err := meta.AddPartition(context.TODO(), &model.Partition{CollectionID: 100, State: pb.PartitionState_PartitionCreated}) err := meta.AddPartition(context.TODO(), &model.Partition{CollectionID: 100, State: pb.PartitionState_PartitionCreating})
assert.NoError(t, err) assert.NoError(t, err)
}) })
} }

View File

@ -500,12 +500,20 @@ func withValidQueryCoord() Opt {
succStatus(), nil, succStatus(), nil,
) )
qc.EXPECT().ReleasePartitions(mock.Anything, mock.Anything).Return(
succStatus(), nil,
)
qc.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).Return( qc.EXPECT().GetSegmentInfo(mock.Anything, mock.Anything).Return(
&querypb.GetSegmentInfoResponse{ &querypb.GetSegmentInfoResponse{
Status: succStatus(), Status: succStatus(),
}, nil, }, nil,
) )
qc.EXPECT().SyncNewCreatedPartition(mock.Anything, mock.Anything).Return(
succStatus(), nil,
)
return withQueryCoord(qc) return withQueryCoord(qc)
} }
@ -779,8 +787,10 @@ func withMetricsCacheManager() Opt {
type mockBroker struct { type mockBroker struct {
Broker Broker
ReleaseCollectionFunc func(ctx context.Context, collectionID UniqueID) error ReleaseCollectionFunc func(ctx context.Context, collectionID UniqueID) error
GetQuerySegmentInfoFunc func(ctx context.Context, collectionID int64, segIDs []int64) (retResp *querypb.GetSegmentInfoResponse, retErr error) ReleasePartitionsFunc func(ctx context.Context, collectionID UniqueID, partitionIDs ...UniqueID) error
SyncNewCreatedPartitionFunc func(ctx context.Context, collectionID UniqueID, partitionID UniqueID) error
GetQuerySegmentInfoFunc func(ctx context.Context, collectionID int64, segIDs []int64) (retResp *querypb.GetSegmentInfoResponse, retErr error)
WatchChannelsFunc func(ctx context.Context, info *watchInfo) error WatchChannelsFunc func(ctx context.Context, info *watchInfo) error
UnwatchChannelsFunc func(ctx context.Context, info *watchInfo) error UnwatchChannelsFunc func(ctx context.Context, info *watchInfo) error
@ -814,6 +824,14 @@ func (b mockBroker) ReleaseCollection(ctx context.Context, collectionID UniqueID
return b.ReleaseCollectionFunc(ctx, collectionID) return b.ReleaseCollectionFunc(ctx, collectionID)
} }
func (b mockBroker) ReleasePartitions(ctx context.Context, collectionID UniqueID, partitionIDs ...UniqueID) error {
return b.ReleasePartitionsFunc(ctx, collectionID)
}
func (b mockBroker) SyncNewCreatedPartition(ctx context.Context, collectionID UniqueID, partitionID UniqueID) error {
return b.SyncNewCreatedPartitionFunc(ctx, collectionID, partitionID)
}
func (b mockBroker) DropCollectionIndex(ctx context.Context, collID UniqueID, partIDs []UniqueID) error { func (b mockBroker) DropCollectionIndex(ctx context.Context, collID UniqueID, partIDs []UniqueID) error {
return b.DropCollectionIndexFunc(ctx, collID, partIDs) return b.DropCollectionIndexFunc(ctx, collID, partIDs)
} }

View File

@ -273,6 +273,44 @@ func (s *releaseCollectionStep) Weight() stepPriority {
return stepPriorityUrgent return stepPriorityUrgent
} }
type releasePartitionsStep struct {
baseStep
collectionID UniqueID
partitionIDs []UniqueID
}
func (s *releasePartitionsStep) Execute(ctx context.Context) ([]nestedStep, error) {
err := s.core.broker.ReleasePartitions(ctx, s.collectionID, s.partitionIDs...)
return nil, err
}
func (s *releasePartitionsStep) Desc() string {
return fmt.Sprintf("release partitions, collectionID=%d, partitionIDs=%v", s.collectionID, s.partitionIDs)
}
func (s *releasePartitionsStep) Weight() stepPriority {
return stepPriorityUrgent
}
type syncNewCreatedPartitionStep struct {
baseStep
collectionID UniqueID
partitionID UniqueID
}
func (s *syncNewCreatedPartitionStep) Execute(ctx context.Context) ([]nestedStep, error) {
err := s.core.broker.SyncNewCreatedPartition(ctx, s.collectionID, s.partitionID)
return nil, err
}
func (s *syncNewCreatedPartitionStep) Desc() string {
return fmt.Sprintf("sync new partition, collectionID=%d, partitionID=%d", s.partitionID, s.partitionID)
}
func (s *syncNewCreatedPartitionStep) Weight() stepPriority {
return stepPriorityUrgent
}
type dropIndexStep struct { type dropIndexStep struct {
baseStep baseStep
collID UniqueID collID UniqueID

View File

@ -1,4 +1,4 @@
// Code generated by mockery v2.16.0. DO NOT EDIT. // Code generated by mockery v2.14.0. DO NOT EDIT.
package types package types
@ -59,8 +59,8 @@ type MockQueryCoord_CheckHealth_Call struct {
} }
// CheckHealth is a helper method to define mock.On call // CheckHealth is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *milvuspb.CheckHealthRequest // - req *milvuspb.CheckHealthRequest
func (_e *MockQueryCoord_Expecter) CheckHealth(ctx interface{}, req interface{}) *MockQueryCoord_CheckHealth_Call { func (_e *MockQueryCoord_Expecter) CheckHealth(ctx interface{}, req interface{}) *MockQueryCoord_CheckHealth_Call {
return &MockQueryCoord_CheckHealth_Call{Call: _e.mock.On("CheckHealth", ctx, req)} return &MockQueryCoord_CheckHealth_Call{Call: _e.mock.On("CheckHealth", ctx, req)}
} }
@ -106,8 +106,8 @@ type MockQueryCoord_CreateResourceGroup_Call struct {
} }
// CreateResourceGroup is a helper method to define mock.On call // CreateResourceGroup is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *milvuspb.CreateResourceGroupRequest // - req *milvuspb.CreateResourceGroupRequest
func (_e *MockQueryCoord_Expecter) CreateResourceGroup(ctx interface{}, req interface{}) *MockQueryCoord_CreateResourceGroup_Call { func (_e *MockQueryCoord_Expecter) CreateResourceGroup(ctx interface{}, req interface{}) *MockQueryCoord_CreateResourceGroup_Call {
return &MockQueryCoord_CreateResourceGroup_Call{Call: _e.mock.On("CreateResourceGroup", ctx, req)} return &MockQueryCoord_CreateResourceGroup_Call{Call: _e.mock.On("CreateResourceGroup", ctx, req)}
} }
@ -153,8 +153,8 @@ type MockQueryCoord_DescribeResourceGroup_Call struct {
} }
// DescribeResourceGroup is a helper method to define mock.On call // DescribeResourceGroup is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *querypb.DescribeResourceGroupRequest // - req *querypb.DescribeResourceGroupRequest
func (_e *MockQueryCoord_Expecter) DescribeResourceGroup(ctx interface{}, req interface{}) *MockQueryCoord_DescribeResourceGroup_Call { func (_e *MockQueryCoord_Expecter) DescribeResourceGroup(ctx interface{}, req interface{}) *MockQueryCoord_DescribeResourceGroup_Call {
return &MockQueryCoord_DescribeResourceGroup_Call{Call: _e.mock.On("DescribeResourceGroup", ctx, req)} return &MockQueryCoord_DescribeResourceGroup_Call{Call: _e.mock.On("DescribeResourceGroup", ctx, req)}
} }
@ -200,8 +200,8 @@ type MockQueryCoord_DropResourceGroup_Call struct {
} }
// DropResourceGroup is a helper method to define mock.On call // DropResourceGroup is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *milvuspb.DropResourceGroupRequest // - req *milvuspb.DropResourceGroupRequest
func (_e *MockQueryCoord_Expecter) DropResourceGroup(ctx interface{}, req interface{}) *MockQueryCoord_DropResourceGroup_Call { func (_e *MockQueryCoord_Expecter) DropResourceGroup(ctx interface{}, req interface{}) *MockQueryCoord_DropResourceGroup_Call {
return &MockQueryCoord_DropResourceGroup_Call{Call: _e.mock.On("DropResourceGroup", ctx, req)} return &MockQueryCoord_DropResourceGroup_Call{Call: _e.mock.On("DropResourceGroup", ctx, req)}
} }
@ -247,7 +247,7 @@ type MockQueryCoord_GetComponentStates_Call struct {
} }
// GetComponentStates is a helper method to define mock.On call // GetComponentStates is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
func (_e *MockQueryCoord_Expecter) GetComponentStates(ctx interface{}) *MockQueryCoord_GetComponentStates_Call { func (_e *MockQueryCoord_Expecter) GetComponentStates(ctx interface{}) *MockQueryCoord_GetComponentStates_Call {
return &MockQueryCoord_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", ctx)} return &MockQueryCoord_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", ctx)}
} }
@ -293,8 +293,8 @@ type MockQueryCoord_GetMetrics_Call struct {
} }
// GetMetrics is a helper method to define mock.On call // GetMetrics is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *milvuspb.GetMetricsRequest // - req *milvuspb.GetMetricsRequest
func (_e *MockQueryCoord_Expecter) GetMetrics(ctx interface{}, req interface{}) *MockQueryCoord_GetMetrics_Call { func (_e *MockQueryCoord_Expecter) GetMetrics(ctx interface{}, req interface{}) *MockQueryCoord_GetMetrics_Call {
return &MockQueryCoord_GetMetrics_Call{Call: _e.mock.On("GetMetrics", ctx, req)} return &MockQueryCoord_GetMetrics_Call{Call: _e.mock.On("GetMetrics", ctx, req)}
} }
@ -340,8 +340,8 @@ type MockQueryCoord_GetPartitionStates_Call struct {
} }
// GetPartitionStates is a helper method to define mock.On call // GetPartitionStates is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *querypb.GetPartitionStatesRequest // - req *querypb.GetPartitionStatesRequest
func (_e *MockQueryCoord_Expecter) GetPartitionStates(ctx interface{}, req interface{}) *MockQueryCoord_GetPartitionStates_Call { func (_e *MockQueryCoord_Expecter) GetPartitionStates(ctx interface{}, req interface{}) *MockQueryCoord_GetPartitionStates_Call {
return &MockQueryCoord_GetPartitionStates_Call{Call: _e.mock.On("GetPartitionStates", ctx, req)} return &MockQueryCoord_GetPartitionStates_Call{Call: _e.mock.On("GetPartitionStates", ctx, req)}
} }
@ -387,8 +387,8 @@ type MockQueryCoord_GetReplicas_Call struct {
} }
// GetReplicas is a helper method to define mock.On call // GetReplicas is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *milvuspb.GetReplicasRequest // - req *milvuspb.GetReplicasRequest
func (_e *MockQueryCoord_Expecter) GetReplicas(ctx interface{}, req interface{}) *MockQueryCoord_GetReplicas_Call { func (_e *MockQueryCoord_Expecter) GetReplicas(ctx interface{}, req interface{}) *MockQueryCoord_GetReplicas_Call {
return &MockQueryCoord_GetReplicas_Call{Call: _e.mock.On("GetReplicas", ctx, req)} return &MockQueryCoord_GetReplicas_Call{Call: _e.mock.On("GetReplicas", ctx, req)}
} }
@ -434,8 +434,8 @@ type MockQueryCoord_GetSegmentInfo_Call struct {
} }
// GetSegmentInfo is a helper method to define mock.On call // GetSegmentInfo is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *querypb.GetSegmentInfoRequest // - req *querypb.GetSegmentInfoRequest
func (_e *MockQueryCoord_Expecter) GetSegmentInfo(ctx interface{}, req interface{}) *MockQueryCoord_GetSegmentInfo_Call { func (_e *MockQueryCoord_Expecter) GetSegmentInfo(ctx interface{}, req interface{}) *MockQueryCoord_GetSegmentInfo_Call {
return &MockQueryCoord_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", ctx, req)} return &MockQueryCoord_GetSegmentInfo_Call{Call: _e.mock.On("GetSegmentInfo", ctx, req)}
} }
@ -481,8 +481,8 @@ type MockQueryCoord_GetShardLeaders_Call struct {
} }
// GetShardLeaders is a helper method to define mock.On call // GetShardLeaders is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *querypb.GetShardLeadersRequest // - req *querypb.GetShardLeadersRequest
func (_e *MockQueryCoord_Expecter) GetShardLeaders(ctx interface{}, req interface{}) *MockQueryCoord_GetShardLeaders_Call { func (_e *MockQueryCoord_Expecter) GetShardLeaders(ctx interface{}, req interface{}) *MockQueryCoord_GetShardLeaders_Call {
return &MockQueryCoord_GetShardLeaders_Call{Call: _e.mock.On("GetShardLeaders", ctx, req)} return &MockQueryCoord_GetShardLeaders_Call{Call: _e.mock.On("GetShardLeaders", ctx, req)}
} }
@ -528,7 +528,7 @@ type MockQueryCoord_GetStatisticsChannel_Call struct {
} }
// GetStatisticsChannel is a helper method to define mock.On call // GetStatisticsChannel is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
func (_e *MockQueryCoord_Expecter) GetStatisticsChannel(ctx interface{}) *MockQueryCoord_GetStatisticsChannel_Call { func (_e *MockQueryCoord_Expecter) GetStatisticsChannel(ctx interface{}) *MockQueryCoord_GetStatisticsChannel_Call {
return &MockQueryCoord_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", ctx)} return &MockQueryCoord_GetStatisticsChannel_Call{Call: _e.mock.On("GetStatisticsChannel", ctx)}
} }
@ -574,7 +574,7 @@ type MockQueryCoord_GetTimeTickChannel_Call struct {
} }
// GetTimeTickChannel is a helper method to define mock.On call // GetTimeTickChannel is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
func (_e *MockQueryCoord_Expecter) GetTimeTickChannel(ctx interface{}) *MockQueryCoord_GetTimeTickChannel_Call { func (_e *MockQueryCoord_Expecter) GetTimeTickChannel(ctx interface{}) *MockQueryCoord_GetTimeTickChannel_Call {
return &MockQueryCoord_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", ctx)} return &MockQueryCoord_GetTimeTickChannel_Call{Call: _e.mock.On("GetTimeTickChannel", ctx)}
} }
@ -656,8 +656,8 @@ type MockQueryCoord_ListResourceGroups_Call struct {
} }
// ListResourceGroups is a helper method to define mock.On call // ListResourceGroups is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *milvuspb.ListResourceGroupsRequest // - req *milvuspb.ListResourceGroupsRequest
func (_e *MockQueryCoord_Expecter) ListResourceGroups(ctx interface{}, req interface{}) *MockQueryCoord_ListResourceGroups_Call { func (_e *MockQueryCoord_Expecter) ListResourceGroups(ctx interface{}, req interface{}) *MockQueryCoord_ListResourceGroups_Call {
return &MockQueryCoord_ListResourceGroups_Call{Call: _e.mock.On("ListResourceGroups", ctx, req)} return &MockQueryCoord_ListResourceGroups_Call{Call: _e.mock.On("ListResourceGroups", ctx, req)}
} }
@ -703,8 +703,8 @@ type MockQueryCoord_LoadBalance_Call struct {
} }
// LoadBalance is a helper method to define mock.On call // LoadBalance is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *querypb.LoadBalanceRequest // - req *querypb.LoadBalanceRequest
func (_e *MockQueryCoord_Expecter) LoadBalance(ctx interface{}, req interface{}) *MockQueryCoord_LoadBalance_Call { func (_e *MockQueryCoord_Expecter) LoadBalance(ctx interface{}, req interface{}) *MockQueryCoord_LoadBalance_Call {
return &MockQueryCoord_LoadBalance_Call{Call: _e.mock.On("LoadBalance", ctx, req)} return &MockQueryCoord_LoadBalance_Call{Call: _e.mock.On("LoadBalance", ctx, req)}
} }
@ -750,8 +750,8 @@ type MockQueryCoord_LoadCollection_Call struct {
} }
// LoadCollection is a helper method to define mock.On call // LoadCollection is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *querypb.LoadCollectionRequest // - req *querypb.LoadCollectionRequest
func (_e *MockQueryCoord_Expecter) LoadCollection(ctx interface{}, req interface{}) *MockQueryCoord_LoadCollection_Call { func (_e *MockQueryCoord_Expecter) LoadCollection(ctx interface{}, req interface{}) *MockQueryCoord_LoadCollection_Call {
return &MockQueryCoord_LoadCollection_Call{Call: _e.mock.On("LoadCollection", ctx, req)} return &MockQueryCoord_LoadCollection_Call{Call: _e.mock.On("LoadCollection", ctx, req)}
} }
@ -797,8 +797,8 @@ type MockQueryCoord_LoadPartitions_Call struct {
} }
// LoadPartitions is a helper method to define mock.On call // LoadPartitions is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *querypb.LoadPartitionsRequest // - req *querypb.LoadPartitionsRequest
func (_e *MockQueryCoord_Expecter) LoadPartitions(ctx interface{}, req interface{}) *MockQueryCoord_LoadPartitions_Call { func (_e *MockQueryCoord_Expecter) LoadPartitions(ctx interface{}, req interface{}) *MockQueryCoord_LoadPartitions_Call {
return &MockQueryCoord_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", ctx, req)} return &MockQueryCoord_LoadPartitions_Call{Call: _e.mock.On("LoadPartitions", ctx, req)}
} }
@ -880,8 +880,8 @@ type MockQueryCoord_ReleaseCollection_Call struct {
} }
// ReleaseCollection is a helper method to define mock.On call // ReleaseCollection is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *querypb.ReleaseCollectionRequest // - req *querypb.ReleaseCollectionRequest
func (_e *MockQueryCoord_Expecter) ReleaseCollection(ctx interface{}, req interface{}) *MockQueryCoord_ReleaseCollection_Call { func (_e *MockQueryCoord_Expecter) ReleaseCollection(ctx interface{}, req interface{}) *MockQueryCoord_ReleaseCollection_Call {
return &MockQueryCoord_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", ctx, req)} return &MockQueryCoord_ReleaseCollection_Call{Call: _e.mock.On("ReleaseCollection", ctx, req)}
} }
@ -927,8 +927,8 @@ type MockQueryCoord_ReleasePartitions_Call struct {
} }
// ReleasePartitions is a helper method to define mock.On call // ReleasePartitions is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *querypb.ReleasePartitionsRequest // - req *querypb.ReleasePartitionsRequest
func (_e *MockQueryCoord_Expecter) ReleasePartitions(ctx interface{}, req interface{}) *MockQueryCoord_ReleasePartitions_Call { func (_e *MockQueryCoord_Expecter) ReleasePartitions(ctx interface{}, req interface{}) *MockQueryCoord_ReleasePartitions_Call {
return &MockQueryCoord_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", ctx, req)} return &MockQueryCoord_ReleasePartitions_Call{Call: _e.mock.On("ReleasePartitions", ctx, req)}
} }
@ -956,7 +956,7 @@ type MockQueryCoord_SetAddress_Call struct {
} }
// SetAddress is a helper method to define mock.On call // SetAddress is a helper method to define mock.On call
// - address string // - address string
func (_e *MockQueryCoord_Expecter) SetAddress(address interface{}) *MockQueryCoord_SetAddress_Call { func (_e *MockQueryCoord_Expecter) SetAddress(address interface{}) *MockQueryCoord_SetAddress_Call {
return &MockQueryCoord_SetAddress_Call{Call: _e.mock.On("SetAddress", address)} return &MockQueryCoord_SetAddress_Call{Call: _e.mock.On("SetAddress", address)}
} }
@ -993,7 +993,7 @@ type MockQueryCoord_SetDataCoord_Call struct {
} }
// SetDataCoord is a helper method to define mock.On call // SetDataCoord is a helper method to define mock.On call
// - dataCoord DataCoord // - dataCoord DataCoord
func (_e *MockQueryCoord_Expecter) SetDataCoord(dataCoord interface{}) *MockQueryCoord_SetDataCoord_Call { func (_e *MockQueryCoord_Expecter) SetDataCoord(dataCoord interface{}) *MockQueryCoord_SetDataCoord_Call {
return &MockQueryCoord_SetDataCoord_Call{Call: _e.mock.On("SetDataCoord", dataCoord)} return &MockQueryCoord_SetDataCoord_Call{Call: _e.mock.On("SetDataCoord", dataCoord)}
} }
@ -1021,7 +1021,7 @@ type MockQueryCoord_SetEtcdClient_Call struct {
} }
// SetEtcdClient is a helper method to define mock.On call // SetEtcdClient is a helper method to define mock.On call
// - etcdClient *clientv3.Client // - etcdClient *clientv3.Client
func (_e *MockQueryCoord_Expecter) SetEtcdClient(etcdClient interface{}) *MockQueryCoord_SetEtcdClient_Call { func (_e *MockQueryCoord_Expecter) SetEtcdClient(etcdClient interface{}) *MockQueryCoord_SetEtcdClient_Call {
return &MockQueryCoord_SetEtcdClient_Call{Call: _e.mock.On("SetEtcdClient", etcdClient)} return &MockQueryCoord_SetEtcdClient_Call{Call: _e.mock.On("SetEtcdClient", etcdClient)}
} }
@ -1049,7 +1049,7 @@ type MockQueryCoord_SetQueryNodeCreator_Call struct {
} }
// SetQueryNodeCreator is a helper method to define mock.On call // SetQueryNodeCreator is a helper method to define mock.On call
// - _a0 func(context.Context , string)(QueryNode , error) // - _a0 func(context.Context , string)(QueryNode , error)
func (_e *MockQueryCoord_Expecter) SetQueryNodeCreator(_a0 interface{}) *MockQueryCoord_SetQueryNodeCreator_Call { func (_e *MockQueryCoord_Expecter) SetQueryNodeCreator(_a0 interface{}) *MockQueryCoord_SetQueryNodeCreator_Call {
return &MockQueryCoord_SetQueryNodeCreator_Call{Call: _e.mock.On("SetQueryNodeCreator", _a0)} return &MockQueryCoord_SetQueryNodeCreator_Call{Call: _e.mock.On("SetQueryNodeCreator", _a0)}
} }
@ -1086,7 +1086,7 @@ type MockQueryCoord_SetRootCoord_Call struct {
} }
// SetRootCoord is a helper method to define mock.On call // SetRootCoord is a helper method to define mock.On call
// - rootCoord RootCoord // - rootCoord RootCoord
func (_e *MockQueryCoord_Expecter) SetRootCoord(rootCoord interface{}) *MockQueryCoord_SetRootCoord_Call { func (_e *MockQueryCoord_Expecter) SetRootCoord(rootCoord interface{}) *MockQueryCoord_SetRootCoord_Call {
return &MockQueryCoord_SetRootCoord_Call{Call: _e.mock.On("SetRootCoord", rootCoord)} return &MockQueryCoord_SetRootCoord_Call{Call: _e.mock.On("SetRootCoord", rootCoord)}
} }
@ -1132,8 +1132,8 @@ type MockQueryCoord_ShowCollections_Call struct {
} }
// ShowCollections is a helper method to define mock.On call // ShowCollections is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *querypb.ShowCollectionsRequest // - req *querypb.ShowCollectionsRequest
func (_e *MockQueryCoord_Expecter) ShowCollections(ctx interface{}, req interface{}) *MockQueryCoord_ShowCollections_Call { func (_e *MockQueryCoord_Expecter) ShowCollections(ctx interface{}, req interface{}) *MockQueryCoord_ShowCollections_Call {
return &MockQueryCoord_ShowCollections_Call{Call: _e.mock.On("ShowCollections", ctx, req)} return &MockQueryCoord_ShowCollections_Call{Call: _e.mock.On("ShowCollections", ctx, req)}
} }
@ -1179,8 +1179,8 @@ type MockQueryCoord_ShowConfigurations_Call struct {
} }
// ShowConfigurations is a helper method to define mock.On call // ShowConfigurations is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *internalpb.ShowConfigurationsRequest // - req *internalpb.ShowConfigurationsRequest
func (_e *MockQueryCoord_Expecter) ShowConfigurations(ctx interface{}, req interface{}) *MockQueryCoord_ShowConfigurations_Call { func (_e *MockQueryCoord_Expecter) ShowConfigurations(ctx interface{}, req interface{}) *MockQueryCoord_ShowConfigurations_Call {
return &MockQueryCoord_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", ctx, req)} return &MockQueryCoord_ShowConfigurations_Call{Call: _e.mock.On("ShowConfigurations", ctx, req)}
} }
@ -1226,8 +1226,8 @@ type MockQueryCoord_ShowPartitions_Call struct {
} }
// ShowPartitions is a helper method to define mock.On call // ShowPartitions is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *querypb.ShowPartitionsRequest // - req *querypb.ShowPartitionsRequest
func (_e *MockQueryCoord_Expecter) ShowPartitions(ctx interface{}, req interface{}) *MockQueryCoord_ShowPartitions_Call { func (_e *MockQueryCoord_Expecter) ShowPartitions(ctx interface{}, req interface{}) *MockQueryCoord_ShowPartitions_Call {
return &MockQueryCoord_ShowPartitions_Call{Call: _e.mock.On("ShowPartitions", ctx, req)} return &MockQueryCoord_ShowPartitions_Call{Call: _e.mock.On("ShowPartitions", ctx, req)}
} }
@ -1316,6 +1316,53 @@ func (_c *MockQueryCoord_Stop_Call) Return(_a0 error) *MockQueryCoord_Stop_Call
return _c return _c
} }
// SyncNewCreatedPartition provides a mock function with given fields: ctx, req
func (_m *MockQueryCoord) SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error) {
ret := _m.Called(ctx, req)
var r0 *commonpb.Status
if rf, ok := ret.Get(0).(func(context.Context, *querypb.SyncNewCreatedPartitionRequest) *commonpb.Status); ok {
r0 = rf(ctx, req)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*commonpb.Status)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *querypb.SyncNewCreatedPartitionRequest) error); ok {
r1 = rf(ctx, req)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockQueryCoord_SyncNewCreatedPartition_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SyncNewCreatedPartition'
type MockQueryCoord_SyncNewCreatedPartition_Call struct {
*mock.Call
}
// SyncNewCreatedPartition is a helper method to define mock.On call
// - ctx context.Context
// - req *querypb.SyncNewCreatedPartitionRequest
func (_e *MockQueryCoord_Expecter) SyncNewCreatedPartition(ctx interface{}, req interface{}) *MockQueryCoord_SyncNewCreatedPartition_Call {
return &MockQueryCoord_SyncNewCreatedPartition_Call{Call: _e.mock.On("SyncNewCreatedPartition", ctx, req)}
}
func (_c *MockQueryCoord_SyncNewCreatedPartition_Call) Run(run func(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest)) *MockQueryCoord_SyncNewCreatedPartition_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(*querypb.SyncNewCreatedPartitionRequest))
})
return _c
}
func (_c *MockQueryCoord_SyncNewCreatedPartition_Call) Return(_a0 *commonpb.Status, _a1 error) *MockQueryCoord_SyncNewCreatedPartition_Call {
_c.Call.Return(_a0, _a1)
return _c
}
// TransferNode provides a mock function with given fields: ctx, req // TransferNode provides a mock function with given fields: ctx, req
func (_m *MockQueryCoord) TransferNode(ctx context.Context, req *milvuspb.TransferNodeRequest) (*commonpb.Status, error) { func (_m *MockQueryCoord) TransferNode(ctx context.Context, req *milvuspb.TransferNodeRequest) (*commonpb.Status, error) {
ret := _m.Called(ctx, req) ret := _m.Called(ctx, req)
@ -1345,8 +1392,8 @@ type MockQueryCoord_TransferNode_Call struct {
} }
// TransferNode is a helper method to define mock.On call // TransferNode is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *milvuspb.TransferNodeRequest // - req *milvuspb.TransferNodeRequest
func (_e *MockQueryCoord_Expecter) TransferNode(ctx interface{}, req interface{}) *MockQueryCoord_TransferNode_Call { func (_e *MockQueryCoord_Expecter) TransferNode(ctx interface{}, req interface{}) *MockQueryCoord_TransferNode_Call {
return &MockQueryCoord_TransferNode_Call{Call: _e.mock.On("TransferNode", ctx, req)} return &MockQueryCoord_TransferNode_Call{Call: _e.mock.On("TransferNode", ctx, req)}
} }
@ -1392,8 +1439,8 @@ type MockQueryCoord_TransferReplica_Call struct {
} }
// TransferReplica is a helper method to define mock.On call // TransferReplica is a helper method to define mock.On call
// - ctx context.Context // - ctx context.Context
// - req *querypb.TransferReplicaRequest // - req *querypb.TransferReplicaRequest
func (_e *MockQueryCoord_Expecter) TransferReplica(ctx interface{}, req interface{}) *MockQueryCoord_TransferReplica_Call { func (_e *MockQueryCoord_Expecter) TransferReplica(ctx interface{}, req interface{}) *MockQueryCoord_TransferReplica_Call {
return &MockQueryCoord_TransferReplica_Call{Call: _e.mock.On("TransferReplica", ctx, req)} return &MockQueryCoord_TransferReplica_Call{Call: _e.mock.On("TransferReplica", ctx, req)}
} }
@ -1421,7 +1468,7 @@ type MockQueryCoord_UpdateStateCode_Call struct {
} }
// UpdateStateCode is a helper method to define mock.On call // UpdateStateCode is a helper method to define mock.On call
// - stateCode commonpb.StateCode // - stateCode commonpb.StateCode
func (_e *MockQueryCoord_Expecter) UpdateStateCode(stateCode interface{}) *MockQueryCoord_UpdateStateCode_Call { func (_e *MockQueryCoord_Expecter) UpdateStateCode(stateCode interface{}) *MockQueryCoord_UpdateStateCode_Call {
return &MockQueryCoord_UpdateStateCode_Call{Call: _e.mock.On("UpdateStateCode", stateCode)} return &MockQueryCoord_UpdateStateCode_Call{Call: _e.mock.On("UpdateStateCode", stateCode)}
} }

View File

@ -1330,6 +1330,7 @@ type QueryNode interface {
// All the sealed segments are loaded. // All the sealed segments are loaded.
LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error)
ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error)
LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error)
ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error)
ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error)
GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error)
@ -1374,6 +1375,7 @@ type QueryCoord interface {
ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error)
GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error)
GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error)
SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest) (*commonpb.Status, error)
LoadBalance(ctx context.Context, req *querypb.LoadBalanceRequest) (*commonpb.Status, error) LoadBalance(ctx context.Context, req *querypb.LoadBalanceRequest) (*commonpb.Status, error)
ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error)

View File

@ -82,6 +82,10 @@ func (m *GrpcQueryCoordClient) GetSegmentInfo(ctx context.Context, in *querypb.G
return &querypb.GetSegmentInfoResponse{}, m.Err return &querypb.GetSegmentInfoResponse{}, m.Err
} }
func (m *GrpcQueryCoordClient) SyncNewCreatedPartition(ctx context.Context, req *querypb.SyncNewCreatedPartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err
}
func (m *GrpcQueryCoordClient) LoadBalance(ctx context.Context, in *querypb.LoadBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { func (m *GrpcQueryCoordClient) LoadBalance(ctx context.Context, in *querypb.LoadBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err return &commonpb.Status{}, m.Err
} }

View File

@ -61,6 +61,10 @@ func (m *GrpcQueryNodeClient) ReleaseCollection(ctx context.Context, in *querypb
return &commonpb.Status{}, m.Err return &commonpb.Status{}, m.Err
} }
func (m *GrpcQueryNodeClient) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err
}
func (m *GrpcQueryNodeClient) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { func (m *GrpcQueryNodeClient) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) {
return &commonpb.Status{}, m.Err return &commonpb.Status{}, m.Err
} }

View File

@ -77,6 +77,10 @@ func (q QueryNodeClient) ReleaseCollection(ctx context.Context, req *querypb.Rel
return q.grpcClient.ReleaseCollection(ctx, req) return q.grpcClient.ReleaseCollection(ctx, req)
} }
func (q QueryNodeClient) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
return q.grpcClient.LoadPartitions(ctx, req)
}
func (q QueryNodeClient) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { func (q QueryNodeClient) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
return q.grpcClient.ReleasePartitions(ctx, req) return q.grpcClient.ReleasePartitions(ctx, req)
} }

View File

@ -1116,9 +1116,7 @@ class TestCollectionOperation(TestcaseBase):
partition_w1.insert(cf.gen_default_list_data()) partition_w1.insert(cf.gen_default_list_data())
collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
collection_w.load() collection_w.load()
error = {ct.err_code: 5, ct.err_msg: f'load the partition after load collection is not supported'} partition_w1.load()
partition_w1.load(check_task=CheckTasks.err_res,
check_items=error)
@pytest.mark.tags(CaseLabel.L2) @pytest.mark.tags(CaseLabel.L2)
def test_load_collection_release_partition(self): def test_load_collection_release_partition(self):
@ -1133,9 +1131,7 @@ class TestCollectionOperation(TestcaseBase):
partition_w1.insert(cf.gen_default_list_data()) partition_w1.insert(cf.gen_default_list_data())
collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index) collection_w.create_index(ct.default_float_vec_field_name, index_params=ct.default_flat_index)
collection_w.load() collection_w.load()
error = {ct.err_code: 1, ct.err_msg: f'releasing the partition after load collection is not supported'} partition_w1.release()
partition_w1.release(check_task=CheckTasks.err_res,
check_items=error)
@pytest.mark.tags(CaseLabel.L2) @pytest.mark.tags(CaseLabel.L2)
def test_load_collection_after_release_collection(self): def test_load_collection_after_release_collection(self):