Implement the GetLoadState api (#21257)

Signed-off-by: SimFG <bang.fu@zilliz.com>

Signed-off-by: SimFG <bang.fu@zilliz.com>
This commit is contained in:
SimFG 2022-12-16 14:39:24 +08:00 committed by GitHub
parent 44cc62b81d
commit 63cd4132a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 370 additions and 93 deletions

View File

@ -18,7 +18,6 @@ package proxy
import (
"context"
"errors"
"fmt"
"os"
"strconv"
@ -1424,61 +1423,6 @@ func (node *Proxy) ShowPartitions(ctx context.Context, request *milvuspb.ShowPar
return spt.result, nil
}
func (node *Proxy) getCollectionProgress(ctx context.Context, request *milvuspb.GetLoadingProgressRequest, collectionID int64) (int64, error) {
resp, err := node.queryCoord.ShowCollections(ctx, &querypb.ShowCollectionsRequest{
Base: commonpbutil.UpdateMsgBase(
request.Base,
commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection),
),
CollectionIDs: []int64{collectionID},
})
if err != nil {
return 0, err
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return 0, errors.New(resp.Status.Reason)
}
if len(resp.InMemoryPercentages) == 0 {
return 0, errors.New("fail to show collections from the querycoord, no data")
}
return resp.InMemoryPercentages[0], nil
}
func (node *Proxy) getPartitionProgress(ctx context.Context, request *milvuspb.GetLoadingProgressRequest, collectionID int64) (int64, error) {
IDs2Names := make(map[int64]string)
partitionIDs := make([]int64, 0)
for _, partitionName := range request.PartitionNames {
partitionID, err := globalMetaCache.GetPartitionID(ctx, request.CollectionName, partitionName)
if err != nil {
return 0, err
}
IDs2Names[partitionID] = partitionName
partitionIDs = append(partitionIDs, partitionID)
}
resp, err := node.queryCoord.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{
Base: commonpbutil.UpdateMsgBase(
request.Base,
commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions),
),
CollectionID: collectionID,
PartitionIDs: partitionIDs,
})
if err != nil {
return 0, err
}
if len(resp.InMemoryPercentages) != len(partitionIDs) {
return 0, errors.New("fail to show partitions from the querycoord, invalid data num")
}
var progress int64
for _, p := range resp.InMemoryPercentages {
progress += p
}
progress /= int64(len(partitionIDs))
return progress, nil
}
func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) {
if !node.checkHealthy() {
return &milvuspb.GetLoadingProgressResponse{Status: unhealthyStatus()}, nil
@ -1496,9 +1440,9 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get
getErrResponse := func(err error) *milvuspb.GetLoadingProgressResponse {
log.Warn("fail to get loading progress",
zap.Error(err),
zap.String("collection_name", request.CollectionName),
zap.Strings("partition_name", request.PartitionNames))
zap.Strings("partition_name", request.PartitionNames),
zap.Error(err))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc()
return &milvuspb.GetLoadingProgressResponse{
Status: &commonpb.Status{
@ -1514,6 +1458,13 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get
if err != nil {
return getErrResponse(err), nil
}
if statesResp, err := node.queryCoord.GetComponentStates(ctx); err != nil {
return getErrResponse(err), nil
} else if statesResp.State == nil || statesResp.State.StateCode != commonpb.StateCode_Healthy {
return getErrResponse(fmt.Errorf("the querycoord server isn't healthy, state: %v", statesResp.State)), nil
}
msgBase := commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_SystemInfo),
commonpbutil.WithMsgID(0),
@ -1529,11 +1480,12 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get
var progress int64
if len(request.GetPartitionNames()) == 0 {
if progress, err = node.getCollectionProgress(ctx, request, collectionID); err != nil {
if progress, err = getCollectionProgress(ctx, node.queryCoord, request.GetBase(), collectionID); err != nil {
return getErrResponse(err), nil
}
} else {
if progress, err = node.getPartitionProgress(ctx, request, collectionID); err != nil {
if progress, err = getPartitionProgress(ctx, node.queryCoord, request.GetBase(),
request.GetPartitionNames(), request.GetCollectionName(), collectionID); err != nil {
return getErrResponse(err), nil
}
}
@ -1552,7 +1504,95 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get
}
func (node *Proxy) GetLoadState(ctx context.Context, request *milvuspb.GetLoadStateRequest) (*milvuspb.GetLoadStateResponse, error) {
return nil, nil
if !node.checkHealthy() {
return &milvuspb.GetLoadStateResponse{Status: unhealthyStatus()}, nil
}
method := "GetLoadState"
tr := timerecord.NewTimeRecorder(method)
sp, ctx := trace.StartSpanFromContextWithOperationName(ctx, "Proxy-GetLoadState")
defer sp.Finish()
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()
log := log.Ctx(ctx)
log.Debug(
rpcReceived(method),
zap.Any("request", request))
getErrResponse := func(err error) *milvuspb.GetLoadStateResponse {
log.Warn("fail to get load state",
zap.String("collection_name", request.CollectionName),
zap.Strings("partition_name", request.PartitionNames),
zap.Error(err))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc()
return &milvuspb.GetLoadStateResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}
}
if err := validateCollectionName(request.CollectionName); err != nil {
return getErrResponse(err), nil
}
if statesResp, err := node.queryCoord.GetComponentStates(ctx); err != nil {
return getErrResponse(err), nil
} else if statesResp.State == nil || statesResp.State.StateCode != commonpb.StateCode_Healthy {
return getErrResponse(fmt.Errorf("the querycoord server isn't healthy, state: %v", statesResp.State)), nil
}
successResponse := &milvuspb.GetLoadStateResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
}
defer func() {
log.Debug(
rpcDone(method),
zap.Any("request", request))
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc()
metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
}()
collectionID, err := globalMetaCache.GetCollectionID(ctx, request.CollectionName)
if err != nil {
successResponse.State = commonpb.LoadState_LoadStateNotExist
return successResponse, nil
}
msgBase := commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_SystemInfo),
commonpbutil.WithMsgID(0),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
)
if request.Base == nil {
request.Base = msgBase
} else {
request.Base.MsgID = msgBase.MsgID
request.Base.Timestamp = msgBase.Timestamp
request.Base.SourceID = msgBase.SourceID
}
var progress int64
if len(request.GetPartitionNames()) == 0 {
if progress, err = getCollectionProgress(ctx, node.queryCoord, request.GetBase(), collectionID); err != nil {
successResponse.State = commonpb.LoadState_LoadStateNotLoad
return successResponse, nil
}
} else {
if progress, err = getPartitionProgress(ctx, node.queryCoord, request.GetBase(),
request.GetPartitionNames(), request.GetCollectionName(), collectionID); err != nil {
successResponse.State = commonpb.LoadState_LoadStateNotLoad
return successResponse, nil
}
}
if progress >= 100 {
successResponse.State = commonpb.LoadState_LoadStateLoaded
} else {
successResponse.State = commonpb.LoadState_LoadStateLoading
}
return successResponse, nil
}
// CreateIndex create index for collection.

View File

@ -30,15 +30,33 @@ import (
"github.com/golang/protobuf/proto"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/datacoord"
"github.com/milvus-io/milvus/internal/datanode"
grpcdatacoordclient "github.com/milvus-io/milvus/internal/distributed/datacoord"
grpcdatacoordclient2 "github.com/milvus-io/milvus/internal/distributed/datacoord/client"
grpcdatanode "github.com/milvus-io/milvus/internal/distributed/datanode"
grpcindexcoord "github.com/milvus-io/milvus/internal/distributed/indexcoord"
grpcindexcoordclient "github.com/milvus-io/milvus/internal/distributed/indexcoord/client"
grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode"
grpcquerycoord "github.com/milvus-io/milvus/internal/distributed/querycoord"
grpcquerycoordclient "github.com/milvus-io/milvus/internal/distributed/querycoord/client"
grpcquerynode "github.com/milvus-io/milvus/internal/distributed/querynode"
grpcrootcoord "github.com/milvus-io/milvus/internal/distributed/rootcoord"
rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client"
"github.com/milvus-io/milvus/internal/indexcoord"
"github.com/milvus-io/milvus/internal/indexnode"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
querycoord "github.com/milvus-io/milvus/internal/querycoordv2"
"github.com/milvus-io/milvus/internal/querynode"
"github.com/milvus-io/milvus/internal/rootcoord"
"github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/internal/util/crypto"
@ -53,32 +71,11 @@ import (
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
grpcdatacoordclient "github.com/milvus-io/milvus/internal/distributed/datacoord"
grpcdatacoordclient2 "github.com/milvus-io/milvus/internal/distributed/datacoord/client"
grpcdatanode "github.com/milvus-io/milvus/internal/distributed/datanode"
grpcindexcoord "github.com/milvus-io/milvus/internal/distributed/indexcoord"
grpcindexcoordclient "github.com/milvus-io/milvus/internal/distributed/indexcoord/client"
grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode"
grpcquerycoord "github.com/milvus-io/milvus/internal/distributed/querycoord"
grpcquerycoordclient "github.com/milvus-io/milvus/internal/distributed/querycoord/client"
grpcquerynode "github.com/milvus-io/milvus/internal/distributed/querynode"
grpcrootcoord "github.com/milvus-io/milvus/internal/distributed/rootcoord"
rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client"
"github.com/milvus-io/milvus/internal/datacoord"
"github.com/milvus-io/milvus/internal/datanode"
"github.com/milvus-io/milvus/internal/indexcoord"
"github.com/milvus-io/milvus/internal/indexnode"
querycoord "github.com/milvus-io/milvus/internal/querycoordv2"
"github.com/milvus-io/milvus/internal/querynode"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
)
const (
@ -1050,6 +1047,16 @@ func TestProxy(t *testing.T) {
// default partition
assert.Equal(t, 2, len(resp.PartitionNames))
{
stateResp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{
CollectionName: collectionName,
PartitionNames: resp.PartitionNames,
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode)
assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State)
}
// non-exist collection -> fail
resp, err = proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{
Base: nil,
@ -1201,6 +1208,15 @@ func TestProxy(t *testing.T) {
wg.Add(1)
t.Run("load collection", func(t *testing.T) {
defer wg.Done()
{
stateResp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{
CollectionName: collectionName,
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode)
assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State)
}
resp, err := proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
Base: nil,
DbName: dbName,
@ -1307,6 +1323,15 @@ func TestProxy(t *testing.T) {
assert.NotEqual(t, commonpb.ErrorCode_Success, progressResp.Status.ErrorCode)
assert.Equal(t, int64(0), progressResp.Progress)
}
{
stateResp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{
CollectionName: otherCollectionName,
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode)
assert.Equal(t, commonpb.LoadState_LoadStateNotExist, stateResp.State)
}
})
wg.Add(1)
@ -2516,6 +2541,14 @@ func TestProxy(t *testing.T) {
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
wg.Add(1)
t.Run("GetLoadState fail, unhealthy", func(t *testing.T) {
defer wg.Done()
resp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
wg.Add(1)
t.Run("CreateIndex fail, unhealthy", func(t *testing.T) {
defer wg.Done()
@ -3979,3 +4012,135 @@ func TestProxy_ListImportTasks(t *testing.T) {
func TestProxy_GetStatistics(t *testing.T) {
}
func TestProxy_GetLoadState(t *testing.T) {
originCache := globalMetaCache
m := newMockCache()
m.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
return 1, nil
})
m.setGetPartitionIDFunc(func(ctx context.Context, collectionName string, partitionName string) (typeutil.UniqueID, error) {
return 2, nil
})
globalMetaCache = m
defer func() {
globalMetaCache = originCache
}()
{
q := NewQueryCoordMock()
q.state.Store(commonpb.StateCode_Abnormal)
proxy := &Proxy{queryCoord: q}
proxy.stateCode.Store(commonpb.StateCode_Healthy)
stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, stateResp.Status.ErrorCode)
progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, progressResp.Status.ErrorCode)
}
{
q := NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return nil, errors.New("test")
}), SetQueryCoordShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
return nil, errors.New("test")
}))
q.state.Store(commonpb.StateCode_Healthy)
proxy := &Proxy{queryCoord: q}
proxy.stateCode.Store(commonpb.StateCode_Healthy)
stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode)
assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State)
stateResp, err = proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo", PartitionNames: []string{"p1"}})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode)
assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State)
progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, progressResp.Status.ErrorCode)
assert.Equal(t, int64(0), progressResp.Progress)
progressResp, err = proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo", PartitionNames: []string{"p1"}})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, progressResp.Status.ErrorCode)
assert.Equal(t, int64(0), progressResp.Progress)
}
{
q := NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
CollectionIDs: request.CollectionIDs,
InMemoryPercentages: []int64{},
}, nil
}))
q.state.Store(commonpb.StateCode_Healthy)
proxy := &Proxy{queryCoord: q}
proxy.stateCode.Store(commonpb.StateCode_Healthy)
stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode)
assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State)
progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, progressResp.Status.ErrorCode)
}
{
q := NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
CollectionIDs: request.CollectionIDs,
InMemoryPercentages: []int64{100},
}, nil
}))
q.state.Store(commonpb.StateCode_Healthy)
proxy := &Proxy{queryCoord: q}
proxy.stateCode.Store(commonpb.StateCode_Healthy)
stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo", Base: &commonpb.MsgBase{}})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode)
assert.Equal(t, commonpb.LoadState_LoadStateLoaded, stateResp.State)
stateResp, err = proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: ""})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, stateResp.Status.ErrorCode)
progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, progressResp.Status.ErrorCode)
assert.Equal(t, int64(100), progressResp.Progress)
}
{
q := NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
CollectionIDs: request.CollectionIDs,
InMemoryPercentages: []int64{50},
}, nil
}))
q.state.Store(commonpb.StateCode_Healthy)
proxy := &Proxy{queryCoord: q}
proxy.stateCode.Store(commonpb.StateCode_Healthy)
stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode)
assert.Equal(t, commonpb.LoadState_LoadStateLoading, stateResp.State)
progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, progressResp.Status.ErrorCode)
assert.Equal(t, int64(50), progressResp.Progress)
}
}

View File

@ -24,6 +24,8 @@ import (
"strings"
"time"
"github.com/milvus-io/milvus/internal/util/commonpbutil"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
@ -941,3 +943,73 @@ func checkPrimaryFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstre
return ids, nil
}
func getCollectionProgress(ctx context.Context, queryCoord types.QueryCoord,
msgBase *commonpb.MsgBase, collectionID int64) (int64, error) {
resp, err := queryCoord.ShowCollections(ctx, &querypb.ShowCollectionsRequest{
Base: commonpbutil.UpdateMsgBase(
msgBase,
commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection),
),
CollectionIDs: []int64{collectionID},
})
if err != nil {
log.Warn("fail to show collections", zap.Int64("collection_id", collectionID), zap.Error(err))
return 0, err
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
log.Warn("fail to show collections", zap.Int64("collection_id", collectionID),
zap.String("reason", resp.Status.Reason))
return 0, errors.New(resp.Status.Reason)
}
if len(resp.InMemoryPercentages) == 0 {
errMsg := "fail to show collections from the querycoord, no data"
log.Warn(errMsg, zap.Int64("collection_id", collectionID))
return 0, errors.New(errMsg)
}
return resp.InMemoryPercentages[0], nil
}
func getPartitionProgress(ctx context.Context, queryCoord types.QueryCoord,
msgBase *commonpb.MsgBase, partitionNames []string, collectionName string, collectionID int64) (int64, error) {
IDs2Names := make(map[int64]string)
partitionIDs := make([]int64, 0)
for _, partitionName := range partitionNames {
partitionID, err := globalMetaCache.GetPartitionID(ctx, collectionName, partitionName)
if err != nil {
return 0, err
}
IDs2Names[partitionID] = partitionName
partitionIDs = append(partitionIDs, partitionID)
}
resp, err := queryCoord.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{
Base: commonpbutil.UpdateMsgBase(
msgBase,
commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions),
),
CollectionID: collectionID,
PartitionIDs: partitionIDs,
})
if err != nil {
log.Warn("fail to show partitions", zap.Int64("collection_id", collectionID),
zap.String("collection_name", collectionName),
zap.Strings("partition_names", partitionNames),
zap.Error(err))
return 0, err
}
if len(resp.InMemoryPercentages) != len(partitionIDs) {
errMsg := "fail to show partitions from the querycoord, invalid data num"
log.Warn(errMsg, zap.Int64("collection_id", collectionID),
zap.String("collection_name", collectionName),
zap.Strings("partition_names", partitionNames))
return 0, errors.New(errMsg)
}
var progress int64
for _, p := range resp.InMemoryPercentages {
progress += p
}
progress /= int64(len(partitionIDs))
return progress, nil
}