diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index b9b085e22f..09abe99f45 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -18,7 +18,6 @@ package proxy import ( "context" - "errors" "fmt" "os" "strconv" @@ -1424,61 +1423,6 @@ func (node *Proxy) ShowPartitions(ctx context.Context, request *milvuspb.ShowPar return spt.result, nil } -func (node *Proxy) getCollectionProgress(ctx context.Context, request *milvuspb.GetLoadingProgressRequest, collectionID int64) (int64, error) { - resp, err := node.queryCoord.ShowCollections(ctx, &querypb.ShowCollectionsRequest{ - Base: commonpbutil.UpdateMsgBase( - request.Base, - commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection), - ), - CollectionIDs: []int64{collectionID}, - }) - if err != nil { - return 0, err - } - - if resp.Status.ErrorCode != commonpb.ErrorCode_Success { - return 0, errors.New(resp.Status.Reason) - } - - if len(resp.InMemoryPercentages) == 0 { - return 0, errors.New("fail to show collections from the querycoord, no data") - } - return resp.InMemoryPercentages[0], nil -} - -func (node *Proxy) getPartitionProgress(ctx context.Context, request *milvuspb.GetLoadingProgressRequest, collectionID int64) (int64, error) { - IDs2Names := make(map[int64]string) - partitionIDs := make([]int64, 0) - for _, partitionName := range request.PartitionNames { - partitionID, err := globalMetaCache.GetPartitionID(ctx, request.CollectionName, partitionName) - if err != nil { - return 0, err - } - IDs2Names[partitionID] = partitionName - partitionIDs = append(partitionIDs, partitionID) - } - resp, err := node.queryCoord.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{ - Base: commonpbutil.UpdateMsgBase( - request.Base, - commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions), - ), - CollectionID: collectionID, - PartitionIDs: partitionIDs, - }) - if err != nil { - return 0, err - } - if len(resp.InMemoryPercentages) != len(partitionIDs) { - return 0, errors.New("fail to show partitions from the querycoord, invalid data num") - } - var progress int64 - for _, p := range resp.InMemoryPercentages { - progress += p - } - progress /= int64(len(partitionIDs)) - return progress, nil -} - func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) { if !node.checkHealthy() { return &milvuspb.GetLoadingProgressResponse{Status: unhealthyStatus()}, nil @@ -1496,9 +1440,9 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get getErrResponse := func(err error) *milvuspb.GetLoadingProgressResponse { log.Warn("fail to get loading progress", - zap.Error(err), zap.String("collection_name", request.CollectionName), - zap.Strings("partition_name", request.PartitionNames)) + zap.Strings("partition_name", request.PartitionNames), + zap.Error(err)) metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() return &milvuspb.GetLoadingProgressResponse{ Status: &commonpb.Status{ @@ -1514,6 +1458,13 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get if err != nil { return getErrResponse(err), nil } + + if statesResp, err := node.queryCoord.GetComponentStates(ctx); err != nil { + return getErrResponse(err), nil + } else if statesResp.State == nil || statesResp.State.StateCode != commonpb.StateCode_Healthy { + return getErrResponse(fmt.Errorf("the querycoord server isn't healthy, state: %v", statesResp.State)), nil + } + msgBase := commonpbutil.NewMsgBase( commonpbutil.WithMsgType(commonpb.MsgType_SystemInfo), commonpbutil.WithMsgID(0), @@ -1529,11 +1480,12 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get var progress int64 if len(request.GetPartitionNames()) == 0 { - if progress, err = node.getCollectionProgress(ctx, request, collectionID); err != nil { + if progress, err = getCollectionProgress(ctx, node.queryCoord, request.GetBase(), collectionID); err != nil { return getErrResponse(err), nil } } else { - if progress, err = node.getPartitionProgress(ctx, request, collectionID); err != nil { + if progress, err = getPartitionProgress(ctx, node.queryCoord, request.GetBase(), + request.GetPartitionNames(), request.GetCollectionName(), collectionID); err != nil { return getErrResponse(err), nil } } @@ -1552,7 +1504,95 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get } func (node *Proxy) GetLoadState(ctx context.Context, request *milvuspb.GetLoadStateRequest) (*milvuspb.GetLoadStateResponse, error) { - return nil, nil + if !node.checkHealthy() { + return &milvuspb.GetLoadStateResponse{Status: unhealthyStatus()}, nil + } + method := "GetLoadState" + tr := timerecord.NewTimeRecorder(method) + sp, ctx := trace.StartSpanFromContextWithOperationName(ctx, "Proxy-GetLoadState") + defer sp.Finish() + metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc() + log := log.Ctx(ctx) + + log.Debug( + rpcReceived(method), + zap.Any("request", request)) + + getErrResponse := func(err error) *milvuspb.GetLoadStateResponse { + log.Warn("fail to get load state", + zap.String("collection_name", request.CollectionName), + zap.Strings("partition_name", request.PartitionNames), + zap.Error(err)) + metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc() + return &milvuspb.GetLoadStateResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: err.Error(), + }, + } + } + + if err := validateCollectionName(request.CollectionName); err != nil { + return getErrResponse(err), nil + } + + if statesResp, err := node.queryCoord.GetComponentStates(ctx); err != nil { + return getErrResponse(err), nil + } else if statesResp.State == nil || statesResp.State.StateCode != commonpb.StateCode_Healthy { + return getErrResponse(fmt.Errorf("the querycoord server isn't healthy, state: %v", statesResp.State)), nil + } + + successResponse := &milvuspb.GetLoadStateResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + } + defer func() { + log.Debug( + rpcDone(method), + zap.Any("request", request)) + metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc() + metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds())) + }() + + collectionID, err := globalMetaCache.GetCollectionID(ctx, request.CollectionName) + if err != nil { + successResponse.State = commonpb.LoadState_LoadStateNotExist + return successResponse, nil + } + + msgBase := commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_SystemInfo), + commonpbutil.WithMsgID(0), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ) + if request.Base == nil { + request.Base = msgBase + } else { + request.Base.MsgID = msgBase.MsgID + request.Base.Timestamp = msgBase.Timestamp + request.Base.SourceID = msgBase.SourceID + } + + var progress int64 + if len(request.GetPartitionNames()) == 0 { + if progress, err = getCollectionProgress(ctx, node.queryCoord, request.GetBase(), collectionID); err != nil { + successResponse.State = commonpb.LoadState_LoadStateNotLoad + return successResponse, nil + } + } else { + if progress, err = getPartitionProgress(ctx, node.queryCoord, request.GetBase(), + request.GetPartitionNames(), request.GetCollectionName(), collectionID); err != nil { + successResponse.State = commonpb.LoadState_LoadStateNotLoad + return successResponse, nil + } + } + if progress >= 100 { + successResponse.State = commonpb.LoadState_LoadStateLoaded + } else { + successResponse.State = commonpb.LoadState_LoadStateLoading + } + return successResponse, nil } // CreateIndex create index for collection. diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 5092bc949e..0f6a58506a 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -30,15 +30,33 @@ import ( "github.com/golang/protobuf/proto" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" - "github.com/prometheus/client_golang/prometheus" - "github.com/stretchr/testify/assert" - "go.uber.org/zap" - "google.golang.org/grpc" - "google.golang.org/grpc/keepalive" - + "github.com/milvus-io/milvus-proto/go-api/commonpb" + "github.com/milvus-io/milvus-proto/go-api/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/schemapb" "github.com/milvus-io/milvus/internal/common" + "github.com/milvus-io/milvus/internal/datacoord" + "github.com/milvus-io/milvus/internal/datanode" + grpcdatacoordclient "github.com/milvus-io/milvus/internal/distributed/datacoord" + grpcdatacoordclient2 "github.com/milvus-io/milvus/internal/distributed/datacoord/client" + grpcdatanode "github.com/milvus-io/milvus/internal/distributed/datanode" + grpcindexcoord "github.com/milvus-io/milvus/internal/distributed/indexcoord" + grpcindexcoordclient "github.com/milvus-io/milvus/internal/distributed/indexcoord/client" + grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode" + grpcquerycoord "github.com/milvus-io/milvus/internal/distributed/querycoord" + grpcquerycoordclient "github.com/milvus-io/milvus/internal/distributed/querycoord/client" + grpcquerynode "github.com/milvus-io/milvus/internal/distributed/querynode" + grpcrootcoord "github.com/milvus-io/milvus/internal/distributed/rootcoord" + rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" + "github.com/milvus-io/milvus/internal/indexcoord" + "github.com/milvus-io/milvus/internal/indexnode" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/metrics" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/proxypb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + querycoord "github.com/milvus-io/milvus/internal/querycoordv2" + "github.com/milvus-io/milvus/internal/querynode" "github.com/milvus-io/milvus/internal/rootcoord" "github.com/milvus-io/milvus/internal/util" "github.com/milvus-io/milvus/internal/util/crypto" @@ -53,32 +71,11 @@ import ( "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/trace" "github.com/milvus-io/milvus/internal/util/typeutil" - - "github.com/milvus-io/milvus-proto/go-api/commonpb" - "github.com/milvus-io/milvus-proto/go-api/milvuspb" - "github.com/milvus-io/milvus-proto/go-api/schemapb" - "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/proxypb" - "github.com/milvus-io/milvus/internal/proto/rootcoordpb" - - grpcdatacoordclient "github.com/milvus-io/milvus/internal/distributed/datacoord" - grpcdatacoordclient2 "github.com/milvus-io/milvus/internal/distributed/datacoord/client" - grpcdatanode "github.com/milvus-io/milvus/internal/distributed/datanode" - grpcindexcoord "github.com/milvus-io/milvus/internal/distributed/indexcoord" - grpcindexcoordclient "github.com/milvus-io/milvus/internal/distributed/indexcoord/client" - grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode" - grpcquerycoord "github.com/milvus-io/milvus/internal/distributed/querycoord" - grpcquerycoordclient "github.com/milvus-io/milvus/internal/distributed/querycoord/client" - grpcquerynode "github.com/milvus-io/milvus/internal/distributed/querynode" - grpcrootcoord "github.com/milvus-io/milvus/internal/distributed/rootcoord" - rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client" - - "github.com/milvus-io/milvus/internal/datacoord" - "github.com/milvus-io/milvus/internal/datanode" - "github.com/milvus-io/milvus/internal/indexcoord" - "github.com/milvus-io/milvus/internal/indexnode" - querycoord "github.com/milvus-io/milvus/internal/querycoordv2" - "github.com/milvus-io/milvus/internal/querynode" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/keepalive" ) const ( @@ -1050,6 +1047,16 @@ func TestProxy(t *testing.T) { // default partition assert.Equal(t, 2, len(resp.PartitionNames)) + { + stateResp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{ + CollectionName: collectionName, + PartitionNames: resp.PartitionNames, + }) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode) + assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State) + } + // non-exist collection -> fail resp, err = proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{ Base: nil, @@ -1201,6 +1208,15 @@ func TestProxy(t *testing.T) { wg.Add(1) t.Run("load collection", func(t *testing.T) { defer wg.Done() + { + stateResp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{ + CollectionName: collectionName, + }) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode) + assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State) + } + resp, err := proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{ Base: nil, DbName: dbName, @@ -1307,6 +1323,15 @@ func TestProxy(t *testing.T) { assert.NotEqual(t, commonpb.ErrorCode_Success, progressResp.Status.ErrorCode) assert.Equal(t, int64(0), progressResp.Progress) } + + { + stateResp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{ + CollectionName: otherCollectionName, + }) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode) + assert.Equal(t, commonpb.LoadState_LoadStateNotExist, stateResp.State) + } }) wg.Add(1) @@ -2516,6 +2541,14 @@ func TestProxy(t *testing.T) { assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) }) + wg.Add(1) + t.Run("GetLoadState fail, unhealthy", func(t *testing.T) { + defer wg.Done() + resp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{}) + assert.NoError(t, err) + assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode) + }) + wg.Add(1) t.Run("CreateIndex fail, unhealthy", func(t *testing.T) { defer wg.Done() @@ -3979,3 +4012,135 @@ func TestProxy_ListImportTasks(t *testing.T) { func TestProxy_GetStatistics(t *testing.T) { } + +func TestProxy_GetLoadState(t *testing.T) { + originCache := globalMetaCache + m := newMockCache() + m.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) { + return 1, nil + }) + m.setGetPartitionIDFunc(func(ctx context.Context, collectionName string, partitionName string) (typeutil.UniqueID, error) { + return 2, nil + }) + globalMetaCache = m + defer func() { + globalMetaCache = originCache + }() + + { + q := NewQueryCoordMock() + q.state.Store(commonpb.StateCode_Abnormal) + proxy := &Proxy{queryCoord: q} + proxy.stateCode.Store(commonpb.StateCode_Healthy) + stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"}) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, stateResp.Status.ErrorCode) + + progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"}) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, progressResp.Status.ErrorCode) + } + + { + q := NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { + return nil, errors.New("test") + }), SetQueryCoordShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { + return nil, errors.New("test") + })) + q.state.Store(commonpb.StateCode_Healthy) + proxy := &Proxy{queryCoord: q} + proxy.stateCode.Store(commonpb.StateCode_Healthy) + + stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"}) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode) + assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State) + + stateResp, err = proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo", PartitionNames: []string{"p1"}}) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode) + assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State) + + progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"}) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, progressResp.Status.ErrorCode) + assert.Equal(t, int64(0), progressResp.Progress) + + progressResp, err = proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo", PartitionNames: []string{"p1"}}) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, progressResp.Status.ErrorCode) + assert.Equal(t, int64(0), progressResp.Progress) + } + + { + q := NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { + return &querypb.ShowCollectionsResponse{ + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + CollectionIDs: request.CollectionIDs, + InMemoryPercentages: []int64{}, + }, nil + })) + q.state.Store(commonpb.StateCode_Healthy) + proxy := &Proxy{queryCoord: q} + proxy.stateCode.Store(commonpb.StateCode_Healthy) + + stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"}) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode) + assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State) + + progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"}) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, progressResp.Status.ErrorCode) + } + + { + q := NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { + return &querypb.ShowCollectionsResponse{ + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + CollectionIDs: request.CollectionIDs, + InMemoryPercentages: []int64{100}, + }, nil + })) + q.state.Store(commonpb.StateCode_Healthy) + proxy := &Proxy{queryCoord: q} + proxy.stateCode.Store(commonpb.StateCode_Healthy) + + stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo", Base: &commonpb.MsgBase{}}) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode) + assert.Equal(t, commonpb.LoadState_LoadStateLoaded, stateResp.State) + + stateResp, err = proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: ""}) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_UnexpectedError, stateResp.Status.ErrorCode) + + progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"}) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, progressResp.Status.ErrorCode) + assert.Equal(t, int64(100), progressResp.Progress) + } + + { + q := NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { + return &querypb.ShowCollectionsResponse{ + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + CollectionIDs: request.CollectionIDs, + InMemoryPercentages: []int64{50}, + }, nil + })) + q.state.Store(commonpb.StateCode_Healthy) + proxy := &Proxy{queryCoord: q} + proxy.stateCode.Store(commonpb.StateCode_Healthy) + + stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"}) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode) + assert.Equal(t, commonpb.LoadState_LoadStateLoading, stateResp.State) + + progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"}) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, progressResp.Status.ErrorCode) + assert.Equal(t, int64(50), progressResp.Progress) + } +} diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 327d912429..2b0da419c1 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -24,6 +24,8 @@ import ( "strings" "time" + "github.com/milvus-io/milvus/internal/util/commonpbutil" + "github.com/milvus-io/milvus/internal/mq/msgstream" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" @@ -941,3 +943,73 @@ func checkPrimaryFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstre return ids, nil } + +func getCollectionProgress(ctx context.Context, queryCoord types.QueryCoord, + msgBase *commonpb.MsgBase, collectionID int64) (int64, error) { + resp, err := queryCoord.ShowCollections(ctx, &querypb.ShowCollectionsRequest{ + Base: commonpbutil.UpdateMsgBase( + msgBase, + commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection), + ), + CollectionIDs: []int64{collectionID}, + }) + if err != nil { + log.Warn("fail to show collections", zap.Int64("collection_id", collectionID), zap.Error(err)) + return 0, err + } + + if resp.Status.ErrorCode != commonpb.ErrorCode_Success { + log.Warn("fail to show collections", zap.Int64("collection_id", collectionID), + zap.String("reason", resp.Status.Reason)) + return 0, errors.New(resp.Status.Reason) + } + + if len(resp.InMemoryPercentages) == 0 { + errMsg := "fail to show collections from the querycoord, no data" + log.Warn(errMsg, zap.Int64("collection_id", collectionID)) + return 0, errors.New(errMsg) + } + return resp.InMemoryPercentages[0], nil +} + +func getPartitionProgress(ctx context.Context, queryCoord types.QueryCoord, + msgBase *commonpb.MsgBase, partitionNames []string, collectionName string, collectionID int64) (int64, error) { + IDs2Names := make(map[int64]string) + partitionIDs := make([]int64, 0) + for _, partitionName := range partitionNames { + partitionID, err := globalMetaCache.GetPartitionID(ctx, collectionName, partitionName) + if err != nil { + return 0, err + } + IDs2Names[partitionID] = partitionName + partitionIDs = append(partitionIDs, partitionID) + } + resp, err := queryCoord.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{ + Base: commonpbutil.UpdateMsgBase( + msgBase, + commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions), + ), + CollectionID: collectionID, + PartitionIDs: partitionIDs, + }) + if err != nil { + log.Warn("fail to show partitions", zap.Int64("collection_id", collectionID), + zap.String("collection_name", collectionName), + zap.Strings("partition_names", partitionNames), + zap.Error(err)) + return 0, err + } + if len(resp.InMemoryPercentages) != len(partitionIDs) { + errMsg := "fail to show partitions from the querycoord, invalid data num" + log.Warn(errMsg, zap.Int64("collection_id", collectionID), + zap.String("collection_name", collectionName), + zap.Strings("partition_names", partitionNames)) + return 0, errors.New(errMsg) + } + var progress int64 + for _, p := range resp.InMemoryPercentages { + progress += p + } + progress /= int64(len(partitionIDs)) + return progress, nil +}