mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
Implement the GetLoadState api (#21257)
Signed-off-by: SimFG <bang.fu@zilliz.com> Signed-off-by: SimFG <bang.fu@zilliz.com>
This commit is contained in:
parent
44cc62b81d
commit
63cd4132a6
@ -18,7 +18,6 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
@ -1424,61 +1423,6 @@ func (node *Proxy) ShowPartitions(ctx context.Context, request *milvuspb.ShowPar
|
||||
return spt.result, nil
|
||||
}
|
||||
|
||||
func (node *Proxy) getCollectionProgress(ctx context.Context, request *milvuspb.GetLoadingProgressRequest, collectionID int64) (int64, error) {
|
||||
resp, err := node.queryCoord.ShowCollections(ctx, &querypb.ShowCollectionsRequest{
|
||||
Base: commonpbutil.UpdateMsgBase(
|
||||
request.Base,
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection),
|
||||
),
|
||||
CollectionIDs: []int64{collectionID},
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
return 0, errors.New(resp.Status.Reason)
|
||||
}
|
||||
|
||||
if len(resp.InMemoryPercentages) == 0 {
|
||||
return 0, errors.New("fail to show collections from the querycoord, no data")
|
||||
}
|
||||
return resp.InMemoryPercentages[0], nil
|
||||
}
|
||||
|
||||
func (node *Proxy) getPartitionProgress(ctx context.Context, request *milvuspb.GetLoadingProgressRequest, collectionID int64) (int64, error) {
|
||||
IDs2Names := make(map[int64]string)
|
||||
partitionIDs := make([]int64, 0)
|
||||
for _, partitionName := range request.PartitionNames {
|
||||
partitionID, err := globalMetaCache.GetPartitionID(ctx, request.CollectionName, partitionName)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
IDs2Names[partitionID] = partitionName
|
||||
partitionIDs = append(partitionIDs, partitionID)
|
||||
}
|
||||
resp, err := node.queryCoord.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{
|
||||
Base: commonpbutil.UpdateMsgBase(
|
||||
request.Base,
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions),
|
||||
),
|
||||
CollectionID: collectionID,
|
||||
PartitionIDs: partitionIDs,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(resp.InMemoryPercentages) != len(partitionIDs) {
|
||||
return 0, errors.New("fail to show partitions from the querycoord, invalid data num")
|
||||
}
|
||||
var progress int64
|
||||
for _, p := range resp.InMemoryPercentages {
|
||||
progress += p
|
||||
}
|
||||
progress /= int64(len(partitionIDs))
|
||||
return progress, nil
|
||||
}
|
||||
|
||||
func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.GetLoadingProgressRequest) (*milvuspb.GetLoadingProgressResponse, error) {
|
||||
if !node.checkHealthy() {
|
||||
return &milvuspb.GetLoadingProgressResponse{Status: unhealthyStatus()}, nil
|
||||
@ -1496,9 +1440,9 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get
|
||||
|
||||
getErrResponse := func(err error) *milvuspb.GetLoadingProgressResponse {
|
||||
log.Warn("fail to get loading progress",
|
||||
zap.Error(err),
|
||||
zap.String("collection_name", request.CollectionName),
|
||||
zap.Strings("partition_name", request.PartitionNames))
|
||||
zap.Strings("partition_name", request.PartitionNames),
|
||||
zap.Error(err))
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc()
|
||||
return &milvuspb.GetLoadingProgressResponse{
|
||||
Status: &commonpb.Status{
|
||||
@ -1514,6 +1458,13 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get
|
||||
if err != nil {
|
||||
return getErrResponse(err), nil
|
||||
}
|
||||
|
||||
if statesResp, err := node.queryCoord.GetComponentStates(ctx); err != nil {
|
||||
return getErrResponse(err), nil
|
||||
} else if statesResp.State == nil || statesResp.State.StateCode != commonpb.StateCode_Healthy {
|
||||
return getErrResponse(fmt.Errorf("the querycoord server isn't healthy, state: %v", statesResp.State)), nil
|
||||
}
|
||||
|
||||
msgBase := commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_SystemInfo),
|
||||
commonpbutil.WithMsgID(0),
|
||||
@ -1529,11 +1480,12 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get
|
||||
|
||||
var progress int64
|
||||
if len(request.GetPartitionNames()) == 0 {
|
||||
if progress, err = node.getCollectionProgress(ctx, request, collectionID); err != nil {
|
||||
if progress, err = getCollectionProgress(ctx, node.queryCoord, request.GetBase(), collectionID); err != nil {
|
||||
return getErrResponse(err), nil
|
||||
}
|
||||
} else {
|
||||
if progress, err = node.getPartitionProgress(ctx, request, collectionID); err != nil {
|
||||
if progress, err = getPartitionProgress(ctx, node.queryCoord, request.GetBase(),
|
||||
request.GetPartitionNames(), request.GetCollectionName(), collectionID); err != nil {
|
||||
return getErrResponse(err), nil
|
||||
}
|
||||
}
|
||||
@ -1552,7 +1504,95 @@ func (node *Proxy) GetLoadingProgress(ctx context.Context, request *milvuspb.Get
|
||||
}
|
||||
|
||||
func (node *Proxy) GetLoadState(ctx context.Context, request *milvuspb.GetLoadStateRequest) (*milvuspb.GetLoadStateResponse, error) {
|
||||
return nil, nil
|
||||
if !node.checkHealthy() {
|
||||
return &milvuspb.GetLoadStateResponse{Status: unhealthyStatus()}, nil
|
||||
}
|
||||
method := "GetLoadState"
|
||||
tr := timerecord.NewTimeRecorder(method)
|
||||
sp, ctx := trace.StartSpanFromContextWithOperationName(ctx, "Proxy-GetLoadState")
|
||||
defer sp.Finish()
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.TotalLabel).Inc()
|
||||
log := log.Ctx(ctx)
|
||||
|
||||
log.Debug(
|
||||
rpcReceived(method),
|
||||
zap.Any("request", request))
|
||||
|
||||
getErrResponse := func(err error) *milvuspb.GetLoadStateResponse {
|
||||
log.Warn("fail to get load state",
|
||||
zap.String("collection_name", request.CollectionName),
|
||||
zap.Strings("partition_name", request.PartitionNames),
|
||||
zap.Error(err))
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.FailLabel).Inc()
|
||||
return &milvuspb.GetLoadStateResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: err.Error(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if err := validateCollectionName(request.CollectionName); err != nil {
|
||||
return getErrResponse(err), nil
|
||||
}
|
||||
|
||||
if statesResp, err := node.queryCoord.GetComponentStates(ctx); err != nil {
|
||||
return getErrResponse(err), nil
|
||||
} else if statesResp.State == nil || statesResp.State.StateCode != commonpb.StateCode_Healthy {
|
||||
return getErrResponse(fmt.Errorf("the querycoord server isn't healthy, state: %v", statesResp.State)), nil
|
||||
}
|
||||
|
||||
successResponse := &milvuspb.GetLoadStateResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
},
|
||||
}
|
||||
defer func() {
|
||||
log.Debug(
|
||||
rpcDone(method),
|
||||
zap.Any("request", request))
|
||||
metrics.ProxyFunctionCall.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method, metrics.SuccessLabel).Inc()
|
||||
metrics.ProxyReqLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), method).Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
}()
|
||||
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, request.CollectionName)
|
||||
if err != nil {
|
||||
successResponse.State = commonpb.LoadState_LoadStateNotExist
|
||||
return successResponse, nil
|
||||
}
|
||||
|
||||
msgBase := commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_SystemInfo),
|
||||
commonpbutil.WithMsgID(0),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
)
|
||||
if request.Base == nil {
|
||||
request.Base = msgBase
|
||||
} else {
|
||||
request.Base.MsgID = msgBase.MsgID
|
||||
request.Base.Timestamp = msgBase.Timestamp
|
||||
request.Base.SourceID = msgBase.SourceID
|
||||
}
|
||||
|
||||
var progress int64
|
||||
if len(request.GetPartitionNames()) == 0 {
|
||||
if progress, err = getCollectionProgress(ctx, node.queryCoord, request.GetBase(), collectionID); err != nil {
|
||||
successResponse.State = commonpb.LoadState_LoadStateNotLoad
|
||||
return successResponse, nil
|
||||
}
|
||||
} else {
|
||||
if progress, err = getPartitionProgress(ctx, node.queryCoord, request.GetBase(),
|
||||
request.GetPartitionNames(), request.GetCollectionName(), collectionID); err != nil {
|
||||
successResponse.State = commonpb.LoadState_LoadStateNotLoad
|
||||
return successResponse, nil
|
||||
}
|
||||
}
|
||||
if progress >= 100 {
|
||||
successResponse.State = commonpb.LoadState_LoadStateLoaded
|
||||
} else {
|
||||
successResponse.State = commonpb.LoadState_LoadStateLoading
|
||||
}
|
||||
return successResponse, nil
|
||||
}
|
||||
|
||||
// CreateIndex create index for collection.
|
||||
|
||||
@ -30,15 +30,33 @@ import (
|
||||
"github.com/golang/protobuf/proto"
|
||||
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||
ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/common"
|
||||
"github.com/milvus-io/milvus/internal/datacoord"
|
||||
"github.com/milvus-io/milvus/internal/datanode"
|
||||
grpcdatacoordclient "github.com/milvus-io/milvus/internal/distributed/datacoord"
|
||||
grpcdatacoordclient2 "github.com/milvus-io/milvus/internal/distributed/datacoord/client"
|
||||
grpcdatanode "github.com/milvus-io/milvus/internal/distributed/datanode"
|
||||
grpcindexcoord "github.com/milvus-io/milvus/internal/distributed/indexcoord"
|
||||
grpcindexcoordclient "github.com/milvus-io/milvus/internal/distributed/indexcoord/client"
|
||||
grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode"
|
||||
grpcquerycoord "github.com/milvus-io/milvus/internal/distributed/querycoord"
|
||||
grpcquerycoordclient "github.com/milvus-io/milvus/internal/distributed/querycoord/client"
|
||||
grpcquerynode "github.com/milvus-io/milvus/internal/distributed/querynode"
|
||||
grpcrootcoord "github.com/milvus-io/milvus/internal/distributed/rootcoord"
|
||||
rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client"
|
||||
"github.com/milvus-io/milvus/internal/indexcoord"
|
||||
"github.com/milvus-io/milvus/internal/indexnode"
|
||||
"github.com/milvus-io/milvus/internal/log"
|
||||
"github.com/milvus-io/milvus/internal/metrics"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/proxypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
querycoord "github.com/milvus-io/milvus/internal/querycoordv2"
|
||||
"github.com/milvus-io/milvus/internal/querynode"
|
||||
"github.com/milvus-io/milvus/internal/rootcoord"
|
||||
"github.com/milvus-io/milvus/internal/util"
|
||||
"github.com/milvus-io/milvus/internal/util/crypto"
|
||||
@ -53,32 +71,11 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/internal/util/trace"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/proxypb"
|
||||
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
|
||||
|
||||
grpcdatacoordclient "github.com/milvus-io/milvus/internal/distributed/datacoord"
|
||||
grpcdatacoordclient2 "github.com/milvus-io/milvus/internal/distributed/datacoord/client"
|
||||
grpcdatanode "github.com/milvus-io/milvus/internal/distributed/datanode"
|
||||
grpcindexcoord "github.com/milvus-io/milvus/internal/distributed/indexcoord"
|
||||
grpcindexcoordclient "github.com/milvus-io/milvus/internal/distributed/indexcoord/client"
|
||||
grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode"
|
||||
grpcquerycoord "github.com/milvus-io/milvus/internal/distributed/querycoord"
|
||||
grpcquerycoordclient "github.com/milvus-io/milvus/internal/distributed/querycoord/client"
|
||||
grpcquerynode "github.com/milvus-io/milvus/internal/distributed/querynode"
|
||||
grpcrootcoord "github.com/milvus-io/milvus/internal/distributed/rootcoord"
|
||||
rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/datacoord"
|
||||
"github.com/milvus-io/milvus/internal/datanode"
|
||||
"github.com/milvus-io/milvus/internal/indexcoord"
|
||||
"github.com/milvus-io/milvus/internal/indexnode"
|
||||
querycoord "github.com/milvus-io/milvus/internal/querycoordv2"
|
||||
"github.com/milvus-io/milvus/internal/querynode"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -1050,6 +1047,16 @@ func TestProxy(t *testing.T) {
|
||||
// default partition
|
||||
assert.Equal(t, 2, len(resp.PartitionNames))
|
||||
|
||||
{
|
||||
stateResp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{
|
||||
CollectionName: collectionName,
|
||||
PartitionNames: resp.PartitionNames,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode)
|
||||
assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State)
|
||||
}
|
||||
|
||||
// non-exist collection -> fail
|
||||
resp, err = proxy.ShowPartitions(ctx, &milvuspb.ShowPartitionsRequest{
|
||||
Base: nil,
|
||||
@ -1201,6 +1208,15 @@ func TestProxy(t *testing.T) {
|
||||
wg.Add(1)
|
||||
t.Run("load collection", func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
{
|
||||
stateResp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{
|
||||
CollectionName: collectionName,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode)
|
||||
assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State)
|
||||
}
|
||||
|
||||
resp, err := proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
@ -1307,6 +1323,15 @@ func TestProxy(t *testing.T) {
|
||||
assert.NotEqual(t, commonpb.ErrorCode_Success, progressResp.Status.ErrorCode)
|
||||
assert.Equal(t, int64(0), progressResp.Progress)
|
||||
}
|
||||
|
||||
{
|
||||
stateResp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{
|
||||
CollectionName: otherCollectionName,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode)
|
||||
assert.Equal(t, commonpb.LoadState_LoadStateNotExist, stateResp.State)
|
||||
}
|
||||
})
|
||||
|
||||
wg.Add(1)
|
||||
@ -2516,6 +2541,14 @@ func TestProxy(t *testing.T) {
|
||||
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
})
|
||||
|
||||
wg.Add(1)
|
||||
t.Run("GetLoadState fail, unhealthy", func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
resp, err := proxy.GetLoadState(ctx, &milvuspb.GetLoadStateRequest{})
|
||||
assert.NoError(t, err)
|
||||
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
|
||||
})
|
||||
|
||||
wg.Add(1)
|
||||
t.Run("CreateIndex fail, unhealthy", func(t *testing.T) {
|
||||
defer wg.Done()
|
||||
@ -3979,3 +4012,135 @@ func TestProxy_ListImportTasks(t *testing.T) {
|
||||
func TestProxy_GetStatistics(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestProxy_GetLoadState(t *testing.T) {
|
||||
originCache := globalMetaCache
|
||||
m := newMockCache()
|
||||
m.setGetIDFunc(func(ctx context.Context, collectionName string) (typeutil.UniqueID, error) {
|
||||
return 1, nil
|
||||
})
|
||||
m.setGetPartitionIDFunc(func(ctx context.Context, collectionName string, partitionName string) (typeutil.UniqueID, error) {
|
||||
return 2, nil
|
||||
})
|
||||
globalMetaCache = m
|
||||
defer func() {
|
||||
globalMetaCache = originCache
|
||||
}()
|
||||
|
||||
{
|
||||
q := NewQueryCoordMock()
|
||||
q.state.Store(commonpb.StateCode_Abnormal)
|
||||
proxy := &Proxy{queryCoord: q}
|
||||
proxy.stateCode.Store(commonpb.StateCode_Healthy)
|
||||
stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, stateResp.Status.ErrorCode)
|
||||
|
||||
progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, progressResp.Status.ErrorCode)
|
||||
}
|
||||
|
||||
{
|
||||
q := NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||
return nil, errors.New("test")
|
||||
}), SetQueryCoordShowPartitionsFunc(func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
|
||||
return nil, errors.New("test")
|
||||
}))
|
||||
q.state.Store(commonpb.StateCode_Healthy)
|
||||
proxy := &Proxy{queryCoord: q}
|
||||
proxy.stateCode.Store(commonpb.StateCode_Healthy)
|
||||
|
||||
stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode)
|
||||
assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State)
|
||||
|
||||
stateResp, err = proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo", PartitionNames: []string{"p1"}})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode)
|
||||
assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State)
|
||||
|
||||
progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, progressResp.Status.ErrorCode)
|
||||
assert.Equal(t, int64(0), progressResp.Progress)
|
||||
|
||||
progressResp, err = proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo", PartitionNames: []string{"p1"}})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, progressResp.Status.ErrorCode)
|
||||
assert.Equal(t, int64(0), progressResp.Progress)
|
||||
}
|
||||
|
||||
{
|
||||
q := NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||
return &querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
CollectionIDs: request.CollectionIDs,
|
||||
InMemoryPercentages: []int64{},
|
||||
}, nil
|
||||
}))
|
||||
q.state.Store(commonpb.StateCode_Healthy)
|
||||
proxy := &Proxy{queryCoord: q}
|
||||
proxy.stateCode.Store(commonpb.StateCode_Healthy)
|
||||
|
||||
stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode)
|
||||
assert.Equal(t, commonpb.LoadState_LoadStateNotLoad, stateResp.State)
|
||||
|
||||
progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, progressResp.Status.ErrorCode)
|
||||
}
|
||||
|
||||
{
|
||||
q := NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||
return &querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
CollectionIDs: request.CollectionIDs,
|
||||
InMemoryPercentages: []int64{100},
|
||||
}, nil
|
||||
}))
|
||||
q.state.Store(commonpb.StateCode_Healthy)
|
||||
proxy := &Proxy{queryCoord: q}
|
||||
proxy.stateCode.Store(commonpb.StateCode_Healthy)
|
||||
|
||||
stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo", Base: &commonpb.MsgBase{}})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode)
|
||||
assert.Equal(t, commonpb.LoadState_LoadStateLoaded, stateResp.State)
|
||||
|
||||
stateResp, err = proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: ""})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_UnexpectedError, stateResp.Status.ErrorCode)
|
||||
|
||||
progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, progressResp.Status.ErrorCode)
|
||||
assert.Equal(t, int64(100), progressResp.Progress)
|
||||
}
|
||||
|
||||
{
|
||||
q := NewQueryCoordMock(SetQueryCoordShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||
return &querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
CollectionIDs: request.CollectionIDs,
|
||||
InMemoryPercentages: []int64{50},
|
||||
}, nil
|
||||
}))
|
||||
q.state.Store(commonpb.StateCode_Healthy)
|
||||
proxy := &Proxy{queryCoord: q}
|
||||
proxy.stateCode.Store(commonpb.StateCode_Healthy)
|
||||
|
||||
stateResp, err := proxy.GetLoadState(context.Background(), &milvuspb.GetLoadStateRequest{CollectionName: "foo"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, stateResp.Status.ErrorCode)
|
||||
assert.Equal(t, commonpb.LoadState_LoadStateLoading, stateResp.State)
|
||||
|
||||
progressResp, err := proxy.GetLoadingProgress(context.Background(), &milvuspb.GetLoadingProgressRequest{CollectionName: "foo"})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, progressResp.Status.ErrorCode)
|
||||
assert.Equal(t, int64(50), progressResp.Progress)
|
||||
}
|
||||
}
|
||||
|
||||
@ -24,6 +24,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/commonpbutil"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mq/msgstream"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
@ -941,3 +943,73 @@ func checkPrimaryFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstre
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func getCollectionProgress(ctx context.Context, queryCoord types.QueryCoord,
|
||||
msgBase *commonpb.MsgBase, collectionID int64) (int64, error) {
|
||||
resp, err := queryCoord.ShowCollections(ctx, &querypb.ShowCollectionsRequest{
|
||||
Base: commonpbutil.UpdateMsgBase(
|
||||
msgBase,
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_DescribeCollection),
|
||||
),
|
||||
CollectionIDs: []int64{collectionID},
|
||||
})
|
||||
if err != nil {
|
||||
log.Warn("fail to show collections", zap.Int64("collection_id", collectionID), zap.Error(err))
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
|
||||
log.Warn("fail to show collections", zap.Int64("collection_id", collectionID),
|
||||
zap.String("reason", resp.Status.Reason))
|
||||
return 0, errors.New(resp.Status.Reason)
|
||||
}
|
||||
|
||||
if len(resp.InMemoryPercentages) == 0 {
|
||||
errMsg := "fail to show collections from the querycoord, no data"
|
||||
log.Warn(errMsg, zap.Int64("collection_id", collectionID))
|
||||
return 0, errors.New(errMsg)
|
||||
}
|
||||
return resp.InMemoryPercentages[0], nil
|
||||
}
|
||||
|
||||
func getPartitionProgress(ctx context.Context, queryCoord types.QueryCoord,
|
||||
msgBase *commonpb.MsgBase, partitionNames []string, collectionName string, collectionID int64) (int64, error) {
|
||||
IDs2Names := make(map[int64]string)
|
||||
partitionIDs := make([]int64, 0)
|
||||
for _, partitionName := range partitionNames {
|
||||
partitionID, err := globalMetaCache.GetPartitionID(ctx, collectionName, partitionName)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
IDs2Names[partitionID] = partitionName
|
||||
partitionIDs = append(partitionIDs, partitionID)
|
||||
}
|
||||
resp, err := queryCoord.ShowPartitions(ctx, &querypb.ShowPartitionsRequest{
|
||||
Base: commonpbutil.UpdateMsgBase(
|
||||
msgBase,
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions),
|
||||
),
|
||||
CollectionID: collectionID,
|
||||
PartitionIDs: partitionIDs,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warn("fail to show partitions", zap.Int64("collection_id", collectionID),
|
||||
zap.String("collection_name", collectionName),
|
||||
zap.Strings("partition_names", partitionNames),
|
||||
zap.Error(err))
|
||||
return 0, err
|
||||
}
|
||||
if len(resp.InMemoryPercentages) != len(partitionIDs) {
|
||||
errMsg := "fail to show partitions from the querycoord, invalid data num"
|
||||
log.Warn(errMsg, zap.Int64("collection_id", collectionID),
|
||||
zap.String("collection_name", collectionName),
|
||||
zap.Strings("partition_names", partitionNames))
|
||||
return 0, errors.New(errMsg)
|
||||
}
|
||||
var progress int64
|
||||
for _, p := range resp.InMemoryPercentages {
|
||||
progress += p
|
||||
}
|
||||
progress /= int64(len(partitionIDs))
|
||||
return progress, nil
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user