Cherry-pick from master pr: #45018 #45030
Related to #44761

Refactor proxy shard client management by creating a new internal/proxy/shardclient package. This improves code organization and modularity by:

- Moving load balancing logic (LookAsideBalancer, RoundRobinBalancer) to the shardclient package
- Extracting the shard client manager and related interfaces into a separate package
- Relocating shard leader management and client lifecycle code
- Adding package documentation (README.md, OWNERS)
- Updating proxy code to use the new shardclient package interfaces

This change makes the shard client functionality more maintainable and better encapsulated, reducing coupling in the proxy layer.

Also consolidates the proxy package's mockery generation to use a centralized `.mockery.yaml` configuration file, aligning with the pattern used by other packages like querycoordv2.

Changes

- **Makefile**: Replace multiple individual mockery commands with a single config-based invocation for the `generate-mockery-proxy` target
- **internal/proxy/.mockery.yaml**: Add mockery configuration defining all mock interfaces for the proxy and proxy/shardclient packages
- **Mock files**: Regenerate mocks using the new configuration:
  - `mock_cache.go`: Clean up by removing unused interface methods (credential, shard cache, policy methods)
  - `shardclient/mock_lb_balancer.go`: Update type comments (nodeInfo → NodeInfo)
  - `shardclient/mock_lb_policy.go`: Update formatting
  - `shardclient/mock_shardclient_manager.go`: Fix parameter naming consistency (nodeInfo1 → nodeInfo)
- **task_search_test.go**: Remove obsolete mock expectations for deprecated cache methods

Benefits

- Centralized mockery configuration for easier maintenance
- Consistent with other packages (querycoordv2, etc.)
- Cleaner mock interfaces by removing unused methods
- Better type consistency in generated mocks

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
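For orientation, a minimal sketch of the resulting split of responsibilities between the proxy meta cache and the new shard client manager. The helper name `invalidateCollectionShardCache` is hypothetical; the calls it makes are taken verbatim from the InvalidateCollectionMetaCache hunks later in this diff.

```go
package proxy

import "context"

// Sketch only: after the refactor, collection metadata invalidation stays on the
// proxy-level meta cache, while shard-leader bookkeeping is owned by the
// shardclient manager reachable via node.shardMgr. The helper is illustrative,
// not part of the actual change.
func invalidateCollectionShardCache(ctx context.Context, node *Proxy, dbName, collectionName string, collectionIDs []int64) {
	// Schema/ID metadata: still the meta cache's responsibility.
	globalMetaCache.RemoveCollection(ctx, dbName, collectionName)

	// Shard-leader cache: deprecation and invalidation now go through the
	// shardclient manager instead of globalMetaCache.
	node.shardMgr.DeprecateShardCache(dbName, collectionName)
	node.shardMgr.InvalidateShardLeaderCache(collectionIDs)
}
```

Previously the shard-leader calls also went through `globalMetaCache`; after this change the `Cache` interface in meta_cache.go no longer declares `GetShard`, `DeprecateShardCache`, `InvalidateShardLeaderCache`, or `ListShardLocation` at all, as the meta_cache.go hunks below show.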
parent b3e525609c
commit a592cfc8b4

Makefile: 7 changed lines
@@ -478,12 +478,7 @@ generate-mockery-rootcoord: getdeps
 	$(INSTALL_PATH)/mockery --name=IMetaTable --dir=$(PWD)/internal/rootcoord --output=$(PWD)/internal/rootcoord/mocks --filename=meta_table.go --with-expecter --outpkg=mockrootcoord
 
 generate-mockery-proxy: getdeps
-	$(INSTALL_PATH)/mockery --name=Cache --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_cache.go --structname=MockCache --with-expecter --outpkg=proxy --inpackage
-	$(INSTALL_PATH)/mockery --name=timestampAllocatorInterface --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_tso_test.go --structname=mockTimestampAllocator --with-expecter --outpkg=proxy --inpackage
-	$(INSTALL_PATH)/mockery --name=LBPolicy --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_lb_policy.go --structname=MockLBPolicy --with-expecter --outpkg=proxy --inpackage
-	$(INSTALL_PATH)/mockery --name=LBBalancer --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_lb_balancer.go --structname=MockLBBalancer --with-expecter --outpkg=proxy --inpackage
-	$(INSTALL_PATH)/mockery --name=shardClientMgr --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_shardclient_manager.go --structname=MockShardClientManager --with-expecter --outpkg=proxy --inpackage
-	$(INSTALL_PATH)/mockery --name=channelsMgr --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_channels_manager.go --structname=MockChannelsMgr --with-expecter --outpkg=proxy --inpackage
+	$(INSTALL_PATH)/mockery --config $(PWD)/internal/proxy/.mockery.yaml
 
 generate-mockery-querycoord: getdeps
 	$(INSTALL_PATH)/mockery --config $(PWD)/internal/querycoordv2/.mockery.yaml
 
internal/proxy/.mockery.yaml: new file, 31 lines

@@ -0,0 +1,31 @@
quiet: False
with-expecter: True
inpackage: True
filename: "mock_{{.InterfaceNameSnake}}.go"
mockname: "Mock{{.InterfaceName}}"
outpkg: "{{.PackageName}}"
dir: "{{.InterfaceDir}}"
packages:
  github.com/milvus-io/milvus/internal/proxy:
    interfaces:
      Cache:
      timestampAllocatorInterface:
        config:
          mockname: mockTimestampAllocator
          filename: mock_tso_test.go
      channelsMgr:
        config:
          mockname: MockChannelsMgr
          filename: mock_channels_manager.go
  github.com/milvus-io/milvus/internal/proxy/shardclient:
    interfaces:
      LBPolicy:
        config:
          filename: mock_lb_policy.go
      LBBalancer:
        config:
          filename: mock_lb_balancer.go
      ShardClientMgr:
        config:
          filename: mock_shardclient_manager.go
          mockname: MockShardClientManager
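As a usage note: with this configuration, the consolidated `generate-mockery-proxy` target regenerates each mock into its owning package (`dir: "{{.InterfaceDir}}"`, `inpackage: True`), and proxy tests construct the shardclient mocks from the new package. A minimal sketch follows; the helper below is hypothetical, but its constructors and expectations mirror the test hunks later in this diff (e.g. TestProxy_InvalidateShardLeaderCache and TestRunAnalyzer).

```go
package proxy

import (
	"testing"

	"github.com/stretchr/testify/mock"

	"github.com/milvus-io/milvus/internal/proxy/shardclient"
)

// newShardClientTestDoubles is a hypothetical helper showing how tests obtain
// the mocks generated into internal/proxy/shardclient by this .mockery.yaml.
func newShardClientTestDoubles(t *testing.T) (*shardclient.MockShardClientManager, *shardclient.MockLBPolicy) {
	mgr := shardclient.NewMockShardClientManager(t)
	// Shard-leader cache invalidation is now expected on the shard client manager.
	mgr.EXPECT().InvalidateShardLeaderCache(mock.Anything).Return().Maybe()

	lb := shardclient.NewMockLBPolicy(t)
	// Single-channel execution goes through the relocated LBPolicy interface.
	lb.EXPECT().ExecuteOneChannel(mock.Anything, mock.Anything).Return(nil).Maybe()

	return mgr, lb
}
```

Tests then assign these doubles to `node.shardMgr` and `node.lbPolicy`, as the updated proxy tests below do.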
@@ -40,8 +40,7 @@ func TestValidAuth(t *testing.T) {
 	assert.False(t, res)
 	// normal metadata
 	mix := &MockMixCoordClientInterface{}
-	mgr := newShardClientMgr()
-	err := InitMetaCache(ctx, mix, mgr)
+	err := InitMetaCache(ctx, mix)
 	assert.NoError(t, err)
 	res = validAuth(ctx, []string{crypto.Base64Encode("mockUser:mockPass")})
 	assert.True(t, res)
@@ -72,8 +71,7 @@ func TestAuthenticationInterceptor(t *testing.T) {
 	assert.Error(t, err)
 	// mock metacache
 	queryCoord := &MockMixCoordClientInterface{}
-	mgr := newShardClientMgr()
-	err = InitMetaCache(ctx, queryCoord, mgr)
+	err = InitMetaCache(ctx, queryCoord)
 	assert.NoError(t, err)
 	// with invalid metadata
 	md := metadata.Pairs("xxx", "yyy")
@ -137,12 +137,12 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
|
||||
if request.CollectionID != UniqueID(0) {
|
||||
aliasName = globalMetaCache.RemoveCollectionsByID(ctx, collectionID, request.GetBase().GetTimestamp(), msgType == commonpb.MsgType_DropCollection)
|
||||
for _, name := range aliasName {
|
||||
globalMetaCache.DeprecateShardCache(request.GetDbName(), name)
|
||||
node.shardMgr.DeprecateShardCache(request.GetDbName(), name)
|
||||
}
|
||||
}
|
||||
if collectionName != "" {
|
||||
globalMetaCache.RemoveCollection(ctx, request.GetDbName(), collectionName) // no need to return error, though collection may be not cached
|
||||
globalMetaCache.DeprecateShardCache(request.GetDbName(), collectionName)
|
||||
node.shardMgr.DeprecateShardCache(request.GetDbName(), collectionName)
|
||||
}
|
||||
log.Info("complete to invalidate collection meta cache with collection name", zap.String("type", request.GetBase().GetMsgType().String()))
|
||||
case commonpb.MsgType_LoadCollection, commonpb.MsgType_ReleaseCollection:
|
||||
@ -150,7 +150,7 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
|
||||
if request.CollectionID != UniqueID(0) {
|
||||
aliasName = globalMetaCache.RemoveCollectionsByID(ctx, collectionID, 0, false)
|
||||
for _, name := range aliasName {
|
||||
globalMetaCache.DeprecateShardCache(request.GetDbName(), name)
|
||||
node.shardMgr.DeprecateShardCache(request.GetDbName(), name)
|
||||
}
|
||||
}
|
||||
log.Info("complete to invalidate collection meta cache", zap.String("type", request.GetBase().GetMsgType().String()))
|
||||
@ -165,13 +165,16 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
|
||||
}
|
||||
globalMetaCache.RemoveCollection(ctx, request.GetDbName(), collectionName)
|
||||
log.Info("complete to invalidate collection meta cache", zap.String("type", request.GetBase().GetMsgType().String()))
|
||||
case commonpb.MsgType_DropDatabase, commonpb.MsgType_AlterDatabase:
|
||||
case commonpb.MsgType_DropDatabase:
|
||||
node.shardMgr.RemoveDatabase(request.GetDbName())
|
||||
fallthrough
|
||||
case commonpb.MsgType_AlterDatabase:
|
||||
globalMetaCache.RemoveDatabase(ctx, request.GetDbName())
|
||||
case commonpb.MsgType_AlterCollection, commonpb.MsgType_AlterCollectionField:
|
||||
if request.CollectionID != UniqueID(0) {
|
||||
aliasName = globalMetaCache.RemoveCollectionsByID(ctx, collectionID, 0, false)
|
||||
for _, name := range aliasName {
|
||||
globalMetaCache.DeprecateShardCache(request.GetDbName(), name)
|
||||
node.shardMgr.DeprecateShardCache(request.GetDbName(), name)
|
||||
}
|
||||
}
|
||||
if collectionName != "" {
|
||||
@ -183,13 +186,13 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
|
||||
if request.CollectionID != UniqueID(0) {
|
||||
aliasName = globalMetaCache.RemoveCollectionsByID(ctx, collectionID, request.GetBase().GetTimestamp(), false)
|
||||
for _, name := range aliasName {
|
||||
globalMetaCache.DeprecateShardCache(request.GetDbName(), name)
|
||||
node.shardMgr.DeprecateShardCache(request.GetDbName(), name)
|
||||
}
|
||||
}
|
||||
|
||||
if collectionName != "" {
|
||||
globalMetaCache.RemoveCollection(ctx, request.GetDbName(), collectionName) // no need to return error, though collection may be not cached
|
||||
globalMetaCache.DeprecateShardCache(request.GetDbName(), collectionName)
|
||||
node.shardMgr.DeprecateShardCache(request.GetDbName(), collectionName)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -227,9 +230,8 @@ func (node *Proxy) InvalidateShardLeaderCache(ctx context.Context, request *prox
|
||||
|
||||
log.Info("received request to invalidate shard leader cache", zap.Int64s("collectionIDs", request.GetCollectionIDs()))
|
||||
|
||||
if globalMetaCache != nil {
|
||||
globalMetaCache.InvalidateShardLeaderCache(request.GetCollectionIDs())
|
||||
}
|
||||
node.shardMgr.InvalidateShardLeaderCache(request.GetCollectionIDs())
|
||||
|
||||
log.Info("complete to invalidate shard leader cache", zap.Int64s("collectionIDs", request.GetCollectionIDs()))
|
||||
|
||||
return merr.Success(), nil
|
||||
@ -2791,6 +2793,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,
|
||||
mixCoord: node.mixCoord,
|
||||
node: node,
|
||||
lb: node.lbPolicy,
|
||||
shardClientMgr: node.shardMgr,
|
||||
enableMaterializedView: node.enableMaterializedView,
|
||||
mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(),
|
||||
}
|
||||
@ -3024,6 +3027,7 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
|
||||
mixCoord: node.mixCoord,
|
||||
node: node,
|
||||
lb: node.lbPolicy,
|
||||
shardClientMgr: node.shardMgr,
|
||||
mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(),
|
||||
}
|
||||
|
||||
@ -3399,6 +3403,7 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
|
||||
request: request,
|
||||
mixCoord: node.mixCoord,
|
||||
lb: node.lbPolicy,
|
||||
shardclientMgr: node.shardMgr,
|
||||
mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(),
|
||||
}
|
||||
|
||||
|
||||
@ -40,6 +40,7 @@ import (
|
||||
grpcmixcoordclient "github.com/milvus-io/milvus/internal/distributed/mixcoord/client"
|
||||
mhttp "github.com/milvus-io/milvus/internal/http"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/proxy/shardclient"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
@ -247,8 +248,8 @@ func TestProxy_ResourceGroup(t *testing.T) {
|
||||
node.sched.Start()
|
||||
defer node.sched.Close()
|
||||
|
||||
mgr := newShardClientMgr()
|
||||
InitMetaCache(ctx, qc, mgr)
|
||||
// mgr := newShardClientMgr()
|
||||
InitMetaCache(ctx, qc)
|
||||
|
||||
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
|
||||
|
||||
@ -329,8 +330,8 @@ func TestProxy_InvalidResourceGroupName(t *testing.T) {
|
||||
node.sched.Start()
|
||||
defer node.sched.Close()
|
||||
|
||||
mgr := newShardClientMgr()
|
||||
InitMetaCache(ctx, qc, mgr)
|
||||
// mgr := newShardClientMgr()
|
||||
InitMetaCache(ctx, qc)
|
||||
|
||||
t.Run("create resource group", func(t *testing.T) {
|
||||
resp, err := node.CreateResourceGroup(ctx, &milvuspb.CreateResourceGroupRequest{
|
||||
@ -1162,7 +1163,7 @@ func TestProxyDescribeCollection(t *testing.T) {
|
||||
}, nil).Maybe()
|
||||
mixCoord.On("DescribeCollection", mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe()
|
||||
var err error
|
||||
globalMetaCache, err = NewMetaCache(mixCoord, nil)
|
||||
globalMetaCache, err = NewMetaCache(mixCoord)
|
||||
assert.NoError(t, err)
|
||||
|
||||
t.Run("not healthy", func(t *testing.T) {
|
||||
@ -1589,9 +1590,9 @@ func TestProxy_InvalidateShardLeaderCache(t *testing.T) {
|
||||
cacheBak := globalMetaCache
|
||||
defer func() { globalMetaCache = cacheBak }()
|
||||
// set expectations
|
||||
cache := NewMockCache(t)
|
||||
cache.EXPECT().InvalidateShardLeaderCache(mock.Anything)
|
||||
globalMetaCache = cache
|
||||
mockShardClientMgr := shardclient.NewMockShardClientManager(t)
|
||||
mockShardClientMgr.EXPECT().InvalidateShardLeaderCache(mock.Anything).Return()
|
||||
node.shardMgr = mockShardClientMgr
|
||||
|
||||
resp, err := node.InvalidateShardLeaderCache(context.TODO(), &proxypb.InvalidateShardLeaderCacheRequest{})
|
||||
assert.NoError(t, err)
|
||||
@ -1698,7 +1699,7 @@ func TestRunAnalyzer(t *testing.T) {
|
||||
fieldMap: fieldMap,
|
||||
}, nil)
|
||||
|
||||
lb := NewMockLBPolicy(t)
|
||||
lb := shardclient.NewMockLBPolicy(t)
|
||||
lb.EXPECT().ExecuteOneChannel(mock.Anything, mock.Anything).Return(nil)
|
||||
p.lbPolicy = lb
|
||||
|
||||
|
||||
@ -1,765 +0,0 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
type LBPolicySuite struct {
|
||||
suite.Suite
|
||||
qc types.MixCoordClient
|
||||
qn *mocks.MockQueryNodeClient
|
||||
|
||||
mgr *MockShardClientManager
|
||||
lbBalancer *MockLBBalancer
|
||||
lbPolicy *LBPolicyImpl
|
||||
|
||||
nodeIDs []int64
|
||||
nodes []nodeInfo
|
||||
channels []string
|
||||
qnList []*mocks.MockQueryNode
|
||||
|
||||
collectionName string
|
||||
collectionID int64
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) SetupSuite() {
|
||||
paramtable.Init()
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) SetupTest() {
|
||||
s.nodeIDs = make([]int64, 0)
|
||||
for i := 1; i <= 5; i++ {
|
||||
s.nodeIDs = append(s.nodeIDs, int64(i))
|
||||
s.nodes = append(s.nodes, nodeInfo{
|
||||
nodeID: int64(i),
|
||||
address: "localhost",
|
||||
serviceable: true,
|
||||
})
|
||||
}
|
||||
s.channels = []string{"channel1", "channel2"}
|
||||
successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
|
||||
qc := NewMixCoordMock()
|
||||
qc.GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
|
||||
return &querypb.GetShardLeadersResponse{
|
||||
Status: &successStatus,
|
||||
Shards: []*querypb.ShardLeadersList{
|
||||
{
|
||||
ChannelName: s.channels[0],
|
||||
NodeIds: s.nodeIDs,
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002", "localhost:9003", "localhost:9004"},
|
||||
Serviceable: []bool{true, true, true, true, true},
|
||||
},
|
||||
{
|
||||
ChannelName: s.channels[1],
|
||||
NodeIds: s.nodeIDs,
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002", "localhost:9003", "localhost:9004"},
|
||||
Serviceable: []bool{true, true, true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
qc.ShowLoadPartitionsFunc = func(ctx context.Context, req *querypb.ShowPartitionsRequest, opts ...grpc.CallOption) (*querypb.ShowPartitionsResponse, error) {
|
||||
return &querypb.ShowPartitionsResponse{
|
||||
Status: merr.Success(),
|
||||
PartitionIDs: []int64{1, 2, 3},
|
||||
}, nil
|
||||
}
|
||||
s.qc = qc
|
||||
|
||||
s.qn = mocks.NewMockQueryNodeClient(s.T())
|
||||
s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
|
||||
s.mgr = NewMockShardClientManager(s.T())
|
||||
s.lbBalancer = NewMockLBBalancer(s.T())
|
||||
s.lbBalancer.EXPECT().Start(context.Background()).Maybe()
|
||||
s.lbPolicy = NewLBPolicyImpl(s.mgr)
|
||||
s.lbPolicy.Start(context.Background())
|
||||
s.lbPolicy.getBalancer = func() LBBalancer {
|
||||
return s.lbBalancer
|
||||
}
|
||||
|
||||
err := InitMetaCache(context.Background(), s.qc, s.mgr)
|
||||
s.NoError(err)
|
||||
|
||||
s.collectionName = "test_lb_policy"
|
||||
s.loadCollection()
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) loadCollection() {
|
||||
fieldName2Types := map[string]schemapb.DataType{
|
||||
testBoolField: schemapb.DataType_Bool,
|
||||
testInt32Field: schemapb.DataType_Int32,
|
||||
testInt64Field: schemapb.DataType_Int64,
|
||||
testFloatField: schemapb.DataType_Float,
|
||||
testDoubleField: schemapb.DataType_Double,
|
||||
testFloatVecField: schemapb.DataType_FloatVector,
|
||||
}
|
||||
if enableMultipleVectorFields {
|
||||
fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
|
||||
}
|
||||
|
||||
schema := constructCollectionSchemaByDataType(s.collectionName, fieldName2Types, testInt64Field, false)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
s.NoError(err)
|
||||
|
||||
ctx := context.Background()
|
||||
createColT := &createCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||
CollectionName: s.collectionName,
|
||||
DbName: dbName,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: common.DefaultShardsNum,
|
||||
},
|
||||
ctx: ctx,
|
||||
mixCoord: s.qc,
|
||||
}
|
||||
|
||||
s.NoError(createColT.OnEnqueue())
|
||||
s.NoError(createColT.PreExecute(ctx))
|
||||
s.NoError(createColT.Execute(ctx))
|
||||
s.NoError(createColT.PostExecute(ctx))
|
||||
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, s.collectionName)
|
||||
s.NoError(err)
|
||||
|
||||
status, err := s.qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_LoadCollection,
|
||||
SourceID: paramtable.GetNodeID(),
|
||||
},
|
||||
CollectionID: collectionID,
|
||||
})
|
||||
s.NoError(err)
|
||||
s.Equal(commonpb.ErrorCode_Success, status.ErrorCode)
|
||||
s.collectionID = collectionID
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestSelectNode() {
|
||||
ctx := context.Background()
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(5, nil)
|
||||
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
// shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
}, &typeutil.UniqueSet{})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(5), targetNode.nodeID)
|
||||
|
||||
// test select node failed, then update shard leader cache and retry, expect success
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil)
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
// shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
}, &typeutil.UniqueSet{})
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), targetNode.nodeID)
|
||||
|
||||
// test select node always fails, expected failure
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
// shardLeaders: []nodeInfo{},
|
||||
nq: 1,
|
||||
}, &typeutil.UniqueSet{})
|
||||
s.ErrorIs(err, merr.ErrNodeNotAvailable)
|
||||
|
||||
// test all nodes has been excluded, expected clear excludeNodes and try to select node again
|
||||
excludeNodes := typeutil.NewUniqueSet(s.nodeIDs...)
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
// shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
}, &excludeNodes)
|
||||
s.ErrorIs(err, merr.ErrNodeNotAvailable)
|
||||
|
||||
// test get shard leaders failed, retry to select node failed
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
|
||||
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
|
||||
return nil, merr.ErrServiceUnavailable
|
||||
}
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
// shardLeaders: s.nodes,
|
||||
nq: 1,
|
||||
}, &typeutil.UniqueSet{})
|
||||
s.ErrorIs(err, merr.ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestExecuteWithRetry() {
|
||||
ctx := context.Background()
|
||||
|
||||
// test execute success
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
err := s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
// test select node failed, expected error
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
|
||||
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
s.ErrorIs(err, merr.ErrNodeNotAvailable)
|
||||
|
||||
// test get client failed, and retry failed, expected success
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(2)
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
s.Error(err)
|
||||
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1)
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
|
||||
return availableNodes[0], nil
|
||||
})
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
// test exec failed, then retry success
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
|
||||
return availableNodes[0], nil
|
||||
})
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
counter := 0
|
||||
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
counter++
|
||||
if counter == 1 {
|
||||
return errors.New("fake error")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
// test exec timeout
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
s.qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, context.Canceled).Times(1)
|
||||
s.qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, context.DeadlineExceeded)
|
||||
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
_, err := qn.Search(ctx, nil)
|
||||
return err
|
||||
},
|
||||
})
|
||||
s.True(merr.IsCanceledOrTimeout(err))
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestExecuteOneChannel() {
|
||||
ctx := context.Background()
|
||||
mockErr := errors.New("mock error")
|
||||
// test all channel success
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
err := s.lbPolicy.ExecuteOneChannel(ctx, CollectionWorkLoad{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
// test get shard leader failed
|
||||
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
|
||||
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
|
||||
return nil, mockErr
|
||||
}
|
||||
err = s.lbPolicy.ExecuteOneChannel(ctx, CollectionWorkLoad{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
s.ErrorIs(err, mockErr)
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestExecute() {
|
||||
ctx := context.Background()
|
||||
mockErr := errors.New("mock error")
|
||||
// test all channel success
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
|
||||
return availableNodes[0], nil
|
||||
})
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
// test some channel failed
|
||||
counter := atomic.NewInt64(0)
|
||||
err = s.lbPolicy.Execute(ctx, CollectionWorkLoad{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
// succeed in first execute
|
||||
if counter.Add(1) == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return mockErr
|
||||
},
|
||||
})
|
||||
s.Error(err)
|
||||
s.Equal(int64(6), counter.Load())
|
||||
|
||||
// test get shard leader failed
|
||||
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
|
||||
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
|
||||
return nil, mockErr
|
||||
}
|
||||
err = s.lbPolicy.Execute(ctx, CollectionWorkLoad{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
nq: 1,
|
||||
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
s.ErrorIs(err, mockErr)
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestUpdateCostMetrics() {
|
||||
s.lbBalancer.EXPECT().UpdateCostMetrics(mock.Anything, mock.Anything)
|
||||
s.lbPolicy.UpdateCostMetrics(1, &internalpb.CostAggregation{})
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestNewLBPolicy() {
|
||||
policy := NewLBPolicyImpl(s.mgr)
|
||||
s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*proxy.LookAsideBalancer")
|
||||
policy.Close()
|
||||
|
||||
Params.Save(Params.ProxyCfg.ReplicaSelectionPolicy.Key, "round_robin")
|
||||
policy = NewLBPolicyImpl(s.mgr)
|
||||
s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*proxy.RoundRobinBalancer")
|
||||
policy.Close()
|
||||
|
||||
Params.Save(Params.ProxyCfg.ReplicaSelectionPolicy.Key, "look_aside")
|
||||
policy = NewLBPolicyImpl(s.mgr)
|
||||
s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*proxy.LookAsideBalancer")
|
||||
policy.Close()
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestGetShard() {
|
||||
ctx := context.Background()
|
||||
|
||||
// ErrCollectionNotFullyLoaded is retriable, expected to retry until ctx done or success
|
||||
counter := atomic.NewInt64(0)
|
||||
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
|
||||
|
||||
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
|
||||
counter.Inc()
|
||||
return nil, merr.ErrCollectionNotFullyLoaded
|
||||
}
|
||||
|
||||
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
|
||||
log.Info("return rpc success")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
_, err := s.lbPolicy.GetShard(ctx, dbName, s.collectionName, s.collectionID, s.channels[0], true)
|
||||
s.NoError(err)
|
||||
s.Equal(int64(0), counter.Load())
|
||||
|
||||
// ErrServiceUnavailable is not retriable, expected to fail fast
|
||||
counter.Store(0)
|
||||
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
|
||||
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
|
||||
counter.Inc()
|
||||
return nil, merr.ErrCollectionNotLoaded
|
||||
}
|
||||
_, err = s.lbPolicy.GetShard(ctx, dbName, s.collectionName, s.collectionID, s.channels[0], true)
|
||||
log.Info("check err", zap.Error(err))
|
||||
s.Error(err)
|
||||
s.Equal(int64(1), counter.Load())
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestSelectNodeEdgeCases() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test case 1: Empty shard leaders after refresh, should fail gracefully
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable).Times(1)
|
||||
|
||||
// Setup mock to return empty shard leaders
|
||||
successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
|
||||
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
|
||||
return &querypb.GetShardLeadersResponse{
|
||||
Status: &successStatus,
|
||||
Shards: []*querypb.ShardLeadersList{
|
||||
{
|
||||
ChannelName: s.channels[0],
|
||||
NodeIds: []int64{}, // Empty node list
|
||||
NodeAddrs: []string{},
|
||||
Serviceable: []bool{},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
excludeNodes := typeutil.NewUniqueSet(s.nodeIDs...)
|
||||
_, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
nq: 1,
|
||||
}, &excludeNodes)
|
||||
s.Error(err)
|
||||
|
||||
log.Info("test case 1")
|
||||
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
|
||||
// Test case 2: Single replica scenario - exclude it, refresh shows same single replica, should clear and succeed
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil).Times(1)
|
||||
|
||||
singleNodeList := []int64{1}
|
||||
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
|
||||
return &querypb.GetShardLeadersResponse{
|
||||
Status: &successStatus,
|
||||
Shards: []*querypb.ShardLeadersList{
|
||||
{
|
||||
ChannelName: s.channels[0],
|
||||
NodeIds: singleNodeList,
|
||||
NodeAddrs: []string{"localhost:9000"},
|
||||
Serviceable: []bool{true},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
excludeNodes = typeutil.NewUniqueSet(int64(1)) // Exclude the single node
|
||||
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
nq: 1,
|
||||
}, &excludeNodes)
|
||||
s.NoError(err)
|
||||
s.Equal(int64(1), targetNode.nodeID)
|
||||
s.Equal(0, excludeNodes.Len()) // Should be cleared
|
||||
|
||||
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
|
||||
mixedNodeIDs := []int64{1, 2, 3}
|
||||
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable).Times(1)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil).Times(1)
|
||||
|
||||
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
|
||||
return &querypb.GetShardLeadersResponse{
|
||||
Status: &successStatus,
|
||||
Shards: []*querypb.ShardLeadersList{
|
||||
{
|
||||
ChannelName: s.channels[0],
|
||||
NodeIds: mixedNodeIDs,
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
Serviceable: []bool{true, false, true},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
excludeNodes = typeutil.NewUniqueSet(int64(1)) // Exclude node 1, node 3 should be available
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
nq: 1,
|
||||
}, &excludeNodes)
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), targetNode.nodeID)
|
||||
s.Equal(1, excludeNodes.Len()) // Should NOT be cleared as not all replicas were excluded
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestGetShardLeaderList() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test normal scenario with cache
|
||||
channelList, err := s.lbPolicy.GetShardLeaderList(ctx, dbName, s.collectionName, s.collectionID, true)
|
||||
s.NoError(err)
|
||||
s.Equal(len(s.channels), len(channelList))
|
||||
s.Contains(channelList, s.channels[0])
|
||||
s.Contains(channelList, s.channels[1])
|
||||
|
||||
// Test without cache - should refresh from coordinator
|
||||
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
|
||||
channelList, err = s.lbPolicy.GetShardLeaderList(ctx, dbName, s.collectionName, s.collectionID, false)
|
||||
s.NoError(err)
|
||||
s.Equal(len(s.channels), len(channelList))
|
||||
|
||||
// Test error case - collection not loaded
|
||||
counter := atomic.NewInt64(0)
|
||||
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
|
||||
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
|
||||
counter.Inc()
|
||||
return nil, merr.ErrCollectionNotLoaded
|
||||
}
|
||||
_, err = s.lbPolicy.GetShardLeaderList(ctx, dbName, s.collectionName, s.collectionID, true)
|
||||
s.Error(err)
|
||||
s.ErrorIs(err, merr.ErrCollectionNotLoaded)
|
||||
s.Equal(int64(1), counter.Load())
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestSelectNodeWithExcludeClearing() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test exclude nodes clearing when all replicas are excluded after cache refresh
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
// First attempt fails due to no candidates
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil).Times(1)
|
||||
|
||||
// Setup mock to return only excluded nodes first, then same nodes for retry
|
||||
successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
|
||||
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
|
||||
return &querypb.GetShardLeadersResponse{
|
||||
Status: &successStatus,
|
||||
Shards: []*querypb.ShardLeadersList{
|
||||
{
|
||||
ChannelName: s.channels[0],
|
||||
NodeIds: []int64{1, 2}, // All these will be excluded
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001"},
|
||||
Serviceable: []bool{true, true},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
|
||||
excludeNodes := typeutil.NewUniqueSet(int64(1), int64(2)) // Exclude all available nodes
|
||||
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
nq: 1,
|
||||
}, &excludeNodes)
|
||||
|
||||
s.NoError(err)
|
||||
s.Equal(int64(1), targetNode.nodeID)
|
||||
s.Equal(0, excludeNodes.Len()) // Should be cleared when all replicas were excluded
|
||||
|
||||
// Test exclude nodes NOT cleared when only partial replicas are excluded
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(2, nil).Times(1)
|
||||
|
||||
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
|
||||
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
|
||||
return &querypb.GetShardLeadersResponse{
|
||||
Status: &successStatus,
|
||||
Shards: []*querypb.ShardLeadersList{
|
||||
{
|
||||
ChannelName: s.channels[0],
|
||||
NodeIds: []int64{1, 2, 3}, // Node 2 and 3 are still available
|
||||
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
|
||||
Serviceable: []bool{true, true, true},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
excludeNodes = typeutil.NewUniqueSet(int64(1)) // Only exclude node 1
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
nq: 1,
|
||||
}, &excludeNodes)
|
||||
|
||||
s.NoError(err)
|
||||
s.Equal(int64(2), targetNode.nodeID)
|
||||
s.Equal(1, excludeNodes.Len()) // Should NOT be cleared as not all replicas were excluded
|
||||
|
||||
// Test empty shard leaders scenario
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
|
||||
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
|
||||
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
|
||||
return &querypb.GetShardLeadersResponse{
|
||||
Status: &successStatus,
|
||||
Shards: []*querypb.ShardLeadersList{
|
||||
{
|
||||
ChannelName: s.channels[0],
|
||||
NodeIds: []int64{}, // Empty shard leaders
|
||||
NodeAddrs: []string{},
|
||||
Serviceable: []bool{},
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
excludeNodes = typeutil.NewUniqueSet(int64(1))
|
||||
_, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
db: dbName,
|
||||
collectionName: s.collectionName,
|
||||
collectionID: s.collectionID,
|
||||
channel: s.channels[0],
|
||||
nq: 1,
|
||||
}, &excludeNodes)
|
||||
|
||||
s.Error(err)
|
||||
s.Equal(1, excludeNodes.Len()) // Should NOT be cleared for empty shard leaders
|
||||
}
|
||||
|
||||
func TestLBPolicySuite(t *testing.T) {
|
||||
suite.Run(t, new(LBPolicySuite))
|
||||
}
|
||||
@ -19,13 +19,11 @@ package proxy
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
@ -37,7 +35,6 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/conc"
|
||||
@ -68,11 +65,11 @@ type Cache interface {
|
||||
GetPartitionsIndex(ctx context.Context, database, collectionName string) ([]string, error)
|
||||
// GetCollectionSchema get collection's schema.
|
||||
GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemaInfo, error)
|
||||
GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]nodeInfo, error)
|
||||
GetShardLeaderList(ctx context.Context, database, collectionName string, collectionID int64, withCache bool) ([]string, error)
|
||||
DeprecateShardCache(database, collectionName string)
|
||||
InvalidateShardLeaderCache(collections []int64)
|
||||
ListShardLocation() map[int64]nodeInfo
|
||||
// GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]nodeInfo, error)
|
||||
// GetShardLeaderList(ctx context.Context, database, collectionName string, collectionID int64, withCache bool) ([]string, error)
|
||||
// DeprecateShardCache(database, collectionName string)
|
||||
// InvalidateShardLeaderCache(collections []int64)
|
||||
// ListShardLocation() map[int64]nodeInfo
|
||||
RemoveCollection(ctx context.Context, database, collectionName string)
|
||||
RemoveCollectionsByID(ctx context.Context, collectionID UniqueID, version uint64, removeVersion bool) []string
|
||||
|
||||
@ -288,58 +285,6 @@ func (info *collectionInfo) isCollectionCached() bool {
|
||||
return info != nil && info.collID != UniqueID(0) && info.schema != nil
|
||||
}
|
||||
|
||||
// shardLeaders wraps shard leader mapping for iteration.
|
||||
type shardLeaders struct {
|
||||
idx *atomic.Int64
|
||||
collectionID int64
|
||||
shardLeaders map[string][]nodeInfo
|
||||
}
|
||||
|
||||
func (sl *shardLeaders) Get(channel string) []nodeInfo {
|
||||
return sl.shardLeaders[channel]
|
||||
}
|
||||
|
||||
func (sl *shardLeaders) GetShardLeaderList() []string {
|
||||
return lo.Keys(sl.shardLeaders)
|
||||
}
|
||||
|
||||
type shardLeadersReader struct {
|
||||
leaders *shardLeaders
|
||||
idx int64
|
||||
}
|
||||
|
||||
// Shuffle returns the shuffled shard leader list.
|
||||
func (it shardLeadersReader) Shuffle() map[string][]nodeInfo {
|
||||
result := make(map[string][]nodeInfo)
|
||||
for channel, leaders := range it.leaders.shardLeaders {
|
||||
l := len(leaders)
|
||||
// shuffle all replica at random order
|
||||
shuffled := make([]nodeInfo, l)
|
||||
for i, randIndex := range rand.Perm(l) {
|
||||
shuffled[i] = leaders[randIndex]
|
||||
}
|
||||
|
||||
// make each copy has same probability to be first replica
|
||||
for index, leader := range shuffled {
|
||||
if leader == leaders[int(it.idx)%l] {
|
||||
shuffled[0], shuffled[index] = shuffled[index], shuffled[0]
|
||||
}
|
||||
}
|
||||
|
||||
result[channel] = shuffled
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetReader returns shuffer reader for shard leader.
|
||||
func (sl *shardLeaders) GetReader() shardLeadersReader {
|
||||
idx := sl.idx.Inc()
|
||||
return shardLeadersReader{
|
||||
leaders: sl,
|
||||
idx: idx,
|
||||
}
|
||||
}
|
||||
|
||||
// make sure MetaCache implements Cache.
|
||||
var _ Cache = (*MetaCache)(nil)
|
||||
|
||||
@ -347,18 +292,17 @@ var _ Cache = (*MetaCache)(nil)
|
||||
type MetaCache struct {
|
||||
mixCoord types.MixCoordClient
|
||||
|
||||
dbInfo map[string]*databaseInfo // database -> db_info
|
||||
collInfo map[string]map[string]*collectionInfo // database -> collectionName -> collection_info
|
||||
collLeader map[string]map[string]*shardLeaders // database -> collectionName -> collection_leaders
|
||||
dbInfo map[string]*databaseInfo // database -> db_info
|
||||
collInfo map[string]map[string]*collectionInfo // database -> collectionName -> collection_info
|
||||
|
||||
credMap map[string]*internalpb.CredentialInfo // cache for credential, lazy load
|
||||
privilegeInfos map[string]struct{} // privileges cache
|
||||
userToRoles map[string]map[string]struct{} // user to role cache
|
||||
mu sync.RWMutex
|
||||
credMut sync.RWMutex
|
||||
leaderMut sync.RWMutex
|
||||
shardMgr shardClientMgr
|
||||
sfGlobal conc.Singleflight[*collectionInfo]
|
||||
sfDB conc.Singleflight[*databaseInfo]
|
||||
|
||||
sfGlobal conc.Singleflight[*collectionInfo]
|
||||
sfDB conc.Singleflight[*databaseInfo]
|
||||
|
||||
IDStart int64
|
||||
IDCount int64
|
||||
@ -372,9 +316,9 @@ type MetaCache struct {
|
||||
var globalMetaCache Cache
|
||||
|
||||
// InitMetaCache initializes globalMetaCache
|
||||
func InitMetaCache(ctx context.Context, mixCoord types.MixCoordClient, shardMgr shardClientMgr) error {
|
||||
func InitMetaCache(ctx context.Context, mixCoord types.MixCoordClient) error {
|
||||
var err error
|
||||
globalMetaCache, err = NewMetaCache(mixCoord, shardMgr)
|
||||
globalMetaCache, err = NewMetaCache(mixCoord)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -390,14 +334,12 @@ func InitMetaCache(ctx context.Context, mixCoord types.MixCoordClient, shardMgr
|
||||
}
|
||||
|
||||
// NewMetaCache creates a MetaCache with provided RootCoord and QueryNode
|
||||
func NewMetaCache(mixCoord types.MixCoordClient, shardMgr shardClientMgr) (*MetaCache, error) {
|
||||
func NewMetaCache(mixCoord types.MixCoordClient) (*MetaCache, error) {
|
||||
return &MetaCache{
|
||||
mixCoord: mixCoord,
|
||||
dbInfo: map[string]*databaseInfo{},
|
||||
collInfo: map[string]map[string]*collectionInfo{},
|
||||
collLeader: map[string]map[string]*shardLeaders{},
|
||||
credMap: map[string]*internalpb.CredentialInfo{},
|
||||
shardMgr: shardMgr,
|
||||
privilegeInfos: map[string]struct{}{},
|
||||
userToRoles: map[string]map[string]struct{}{},
|
||||
collectionCacheVersion: make(map[UniqueID]uint64),
|
||||
@ -898,192 +840,12 @@ func (m *MetaCache) RemoveCollectionsByID(ctx context.Context, collectionID Uniq
|
||||
return collNames
|
||||
}
|
||||
|
||||
func (m *MetaCache) GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]nodeInfo, error) {
|
||||
method := "GetShard"
|
||||
// check cache first
|
||||
cacheShardLeaders := m.getCachedShardLeaders(database, collectionName, method)
|
||||
if cacheShardLeaders == nil || !withCache {
|
||||
// refresh shard leader cache
|
||||
newShardLeaders, err := m.updateShardLocationCache(ctx, database, collectionName, collectionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cacheShardLeaders = newShardLeaders
|
||||
}
|
||||
|
||||
return cacheShardLeaders.Get(channel), nil
|
||||
}
|
||||
|
||||
func (m *MetaCache) GetShardLeaderList(ctx context.Context, database, collectionName string, collectionID int64, withCache bool) ([]string, error) {
|
||||
method := "GetShardLeaderList"
|
||||
// check cache first
|
||||
cacheShardLeaders := m.getCachedShardLeaders(database, collectionName, method)
|
||||
if cacheShardLeaders == nil || !withCache {
|
||||
// refresh shard leader cache
|
||||
newShardLeaders, err := m.updateShardLocationCache(ctx, database, collectionName, collectionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cacheShardLeaders = newShardLeaders
|
||||
}
|
||||
|
||||
return cacheShardLeaders.GetShardLeaderList(), nil
|
||||
}
|
||||
|
||||
func (m *MetaCache) getCachedShardLeaders(database, collectionName, caller string) *shardLeaders {
|
||||
m.leaderMut.RLock()
|
||||
var cacheShardLeaders *shardLeaders
|
||||
db, ok := m.collLeader[database]
|
||||
if !ok {
|
||||
cacheShardLeaders = nil
|
||||
} else {
|
||||
cacheShardLeaders = db[collectionName]
|
||||
}
|
||||
m.leaderMut.RUnlock()
|
||||
|
||||
if cacheShardLeaders != nil {
|
||||
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), caller, metrics.CacheHitLabel).Inc()
|
||||
} else {
|
||||
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), caller, metrics.CacheMissLabel).Inc()
|
||||
}
|
||||
|
||||
return cacheShardLeaders
|
||||
}
|
||||
|
||||
func (m *MetaCache) updateShardLocationCache(ctx context.Context, database, collectionName string, collectionID int64) (*shardLeaders, error) {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.String("db", database),
|
||||
zap.String("collectionName", collectionName),
|
||||
zap.Int64("collectionID", collectionID))
|
||||
|
||||
method := "updateShardLocationCache"
|
||||
tr := timerecord.NewTimeRecorder(method)
|
||||
defer metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).
|
||||
Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
|
||||
req := &querypb.GetShardLeadersRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_GetShardLeaders),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
CollectionID: collectionID,
|
||||
WithUnserviceableShards: true,
|
||||
}
|
||||
resp, err := m.mixCoord.GetShardLeaders(ctx, req)
|
||||
if err := merr.CheckRPCCall(resp.GetStatus(), err); err != nil {
|
||||
log.Error("failed to get shard locations",
|
||||
zap.Int64("collectionID", collectionID),
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
shards := parseShardLeaderList2QueryNode(resp.GetShards())
|
||||
|
||||
// convert shards map to string for logging
|
||||
if log.Logger.Level() == zap.DebugLevel {
|
||||
shardStr := make([]string, 0, len(shards))
|
||||
for channel, nodes := range shards {
|
||||
nodeStrs := make([]string, 0, len(nodes))
|
||||
for _, node := range nodes {
|
||||
nodeStrs = append(nodeStrs, node.String())
|
||||
}
|
||||
shardStr = append(shardStr, fmt.Sprintf("%s:[%s]", channel, strings.Join(nodeStrs, ", ")))
|
||||
}
|
||||
log.Debug("update shard leader cache", zap.String("newShardLeaders", strings.Join(shardStr, ", ")))
|
||||
}
|
||||
|
||||
newShardLeaders := &shardLeaders{
|
||||
collectionID: collectionID,
|
||||
shardLeaders: shards,
|
||||
idx: atomic.NewInt64(0),
|
||||
}
|
||||
|
||||
m.leaderMut.Lock()
|
||||
if _, ok := m.collLeader[database]; !ok {
|
||||
m.collLeader[database] = make(map[string]*shardLeaders)
|
||||
}
|
||||
m.collLeader[database][collectionName] = newShardLeaders
|
||||
m.leaderMut.Unlock()
|
||||
|
||||
return newShardLeaders, nil
|
||||
}
|
||||
|
||||
func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) map[string][]nodeInfo {
	shard2QueryNodes := make(map[string][]nodeInfo)

	for _, leaders := range shardsLeaders {
		qns := make([]nodeInfo, len(leaders.GetNodeIds()))

		for j := range qns {
			qns[j] = nodeInfo{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j], leaders.GetServiceable()[j]}
		}

		shard2QueryNodes[leaders.GetChannelName()] = qns
	}

	return shard2QueryNodes
}

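parseShardLeaderList2QueryNode flattens the coordinator response into a channel-to-replica map. A runnable sketch of the same conversion, using hypothetical stand-in structs in place of querypb.ShardLeadersList and the proxy's nodeInfo:

```go
package main

import "fmt"

// shardLeadersList is a stand-in for querypb.ShardLeadersList.
type shardLeadersList struct {
	ChannelName string
	NodeIds     []int64
	NodeAddrs   []string
	Serviceable []bool
}

// replica is a stand-in for the proxy's nodeInfo.
type replica struct {
	nodeID      int64
	address     string
	serviceable bool
}

func main() {
	in := []shardLeadersList{{
		ChannelName: "channel-1",
		NodeIds:     []int64{1, 2},
		NodeAddrs:   []string{"localhost:9000", "localhost:9001"},
		Serviceable: []bool{true, false},
	}}

	// channel name -> replica list, positionally zipping the three parallel slices
	out := make(map[string][]replica)
	for _, l := range in {
		qns := make([]replica, len(l.NodeIds))
		for j := range qns {
			qns[j] = replica{l.NodeIds[j], l.NodeAddrs[j], l.Serviceable[j]}
		}
		out[l.ChannelName] = qns
	}
	fmt.Println(out) // map[channel-1:[{1 localhost:9000 true} {2 localhost:9001 false}]]
}
```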
// ListShardLocation is used by shard client garbage collection.
func (m *MetaCache) ListShardLocation() map[int64]nodeInfo {
	m.leaderMut.RLock()
	defer m.leaderMut.RUnlock()
	shardLeaderInfo := make(map[int64]nodeInfo)

	for _, dbInfo := range m.collLeader {
		for _, shardLeaders := range dbInfo {
			for _, nodeInfos := range shardLeaders.shardLeaders {
				for _, node := range nodeInfos {
					shardLeaderInfo[node.nodeID] = node
				}
			}
		}
	}
	return shardLeaderInfo
}

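ListShardLocation reports every query node still referenced by the cached shard leaders, which is what lets the shard client manager garbage-collect idle connections. A toy sketch of such a sweep, with plain maps standing in for the real client table (the actual purge logic lives in the shard client manager, not here):

```go
package main

import "fmt"

func main() {
	// node IDs currently referenced by the shard-leader cache
	// (what ListShardLocation would report)
	referenced := map[int64]string{1: "localhost:9000", 2: "localhost:9001"}

	// node IDs for which the proxy still holds a client
	clients := map[int64]string{1: "localhost:9000", 2: "localhost:9001", 7: "localhost:9007"}

	// drop clients whose node no longer serves any cached shard
	for nodeID := range clients {
		if _, stillUsed := referenced[nodeID]; !stillUsed {
			delete(clients, nodeID) // the real manager would also close the connection here
		}
	}
	fmt.Println(len(clients)) // 2
}
```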
// DeprecateShardCache clears the shard leader cache of a collection.
func (m *MetaCache) DeprecateShardCache(database, collectionName string) {
	log.Info("deprecate shard cache for collection", zap.String("collectionName", collectionName))
	m.leaderMut.Lock()
	defer m.leaderMut.Unlock()
	dbInfo, ok := m.collLeader[database]
	if ok {
		delete(dbInfo, collectionName)
		if len(dbInfo) == 0 {
			delete(m.collLeader, database)
		}
	}
}

// InvalidateShardLeaderCache is called when a shard leader balance happens.
func (m *MetaCache) InvalidateShardLeaderCache(collections []int64) {
	log.Info("Invalidate shard cache for collections", zap.Int64s("collectionIDs", collections))
	m.leaderMut.Lock()
	defer m.leaderMut.Unlock()
	collectionSet := typeutil.NewUniqueSet(collections...)
	for dbName, dbInfo := range m.collLeader {
		for collectionName, shardLeaders := range dbInfo {
			if collectionSet.Contain(shardLeaders.collectionID) {
				delete(dbInfo, collectionName)
			}
		}
		if len(dbInfo) == 0 {
			delete(m.collLeader, dbName)
		}
	}
}

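InvalidateShardLeaderCache evicts by collection ID, so the next request for an affected collection repopulates the cache from the coordinator. A toy sketch of that eviction shape, with plain maps standing in for one database's cached entries:

```go
package main

import "fmt"

func main() {
	// collectionName -> collectionID, mimicking one database's cached shard-leader entries
	cached := map[string]int64{"collection1": 1, "collection2": 2}

	// collections reported as rebalanced by the coordinator
	invalidated := map[int64]struct{}{1: {}}

	// drop every cached entry whose collection ID is in the notification
	for name, id := range cached {
		if _, hit := invalidated[id]; hit {
			delete(cached, name)
		}
	}
	fmt.Println(cached) // map[collection2:2]
}
```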
func (m *MetaCache) RemoveDatabase(ctx context.Context, database string) {
	log.Ctx(ctx).Debug("remove database", zap.String("name", database))
	m.mu.Lock()
	delete(m.collInfo, database)
	delete(m.dbInfo, database)
	m.mu.Unlock()

	m.leaderMut.Lock()
	delete(m.collLeader, database)
	m.leaderMut.Unlock()
}

func (m *MetaCache) HasDatabase(ctx context.Context, database string) bool {

@ -27,8 +27,6 @@ import (
	"github.com/cockroachdb/errors"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"
	uatomic "go.uber.org/atomic"
	"google.golang.org/grpc"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -47,7 +45,6 @@ import (
	"github.com/milvus-io/milvus/pkg/v2/util/crypto"
	"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
	"github.com/milvus-io/milvus/pkg/v2/util/merr"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

@ -966,8 +963,7 @@ func TestMetaCache_GetCollection(t *testing.T) {
	ctx := context.Background()
	rootCoord := &MockMixCoordClientInterface{}

	mgr := newShardClientMgr()
	err := InitMetaCache(ctx, rootCoord, mgr)
	err := InitMetaCache(ctx, rootCoord)
	assert.NoError(t, err)

	id, err := globalMetaCache.GetCollectionID(ctx, dbName, "collection1")
@ -1019,8 +1015,7 @@ func TestMetaCache_GetBasicCollectionInfo(t *testing.T) {
	ctx := context.Background()
	rootCoord := &MockMixCoordClientInterface{}

	mgr := newShardClientMgr()
	err := InitMetaCache(ctx, rootCoord, mgr)
	err := InitMetaCache(ctx, rootCoord)
	assert.NoError(t, err)

	// should be no data race.
@ -1053,8 +1048,7 @@ func TestMetaCacheGetCollectionWithUpdate(t *testing.T) {
	ctx := context.Background()
	rootCoord := mocks.NewMockMixCoordClient(t)
	rootCoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{Status: merr.Success()}, nil)
	mgr := newShardClientMgr()
	err := InitMetaCache(ctx, rootCoord, mgr)
	err := InitMetaCache(ctx, rootCoord)
	assert.NoError(t, err)
	t.Run("update with name", func(t *testing.T) {
		rootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
@ -1131,8 +1125,7 @@ func TestMetaCache_InitCache(t *testing.T) {
		rootCoord := mocks.NewMockMixCoordClient(t)
		rootCoord.EXPECT().ShowLoadCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
		rootCoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{Status: merr.Success()}, nil).Once()
		mgr := newShardClientMgr()
		err := InitMetaCache(ctx, rootCoord, mgr)
		err := InitMetaCache(ctx, rootCoord)
		assert.NoError(t, err)
	})

@ -1142,8 +1135,7 @@ func TestMetaCache_InitCache(t *testing.T) {
		rootCoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(
			&internalpb.ListPolicyResponse{Status: merr.Status(errors.New("mock list policy error"))},
			nil).Once()
		mgr := newShardClientMgr()
		err := InitMetaCache(ctx, rootCoord, mgr)
		err := InitMetaCache(ctx, rootCoord)
		assert.Error(t, err)
	})

@ -1152,8 +1144,7 @@ func TestMetaCache_InitCache(t *testing.T) {
		rootCoord := mocks.NewMockMixCoordClient(t)
		rootCoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(
			nil, errors.New("mock list policy rpc errorr")).Once()
		mgr := newShardClientMgr()
		err := InitMetaCache(ctx, rootCoord, mgr)
		err := InitMetaCache(ctx, rootCoord)
		assert.Error(t, err)
	})
}
@ -1161,8 +1152,7 @@
func TestMetaCache_GetCollectionName(t *testing.T) {
	ctx := context.Background()
	rootCoord := &MockMixCoordClientInterface{}
	mgr := newShardClientMgr()
	err := InitMetaCache(ctx, rootCoord, mgr)
	err := InitMetaCache(ctx, rootCoord)
	assert.NoError(t, err)

	collection, err := globalMetaCache.GetCollectionName(ctx, GetCurDBNameFromContextOrDefault(ctx), 1)
@ -1213,8 +1203,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) {
func TestMetaCache_GetCollectionFailure(t *testing.T) {
	ctx := context.Background()
	rootCoord := &MockMixCoordClientInterface{}
	mgr := newShardClientMgr()
	err := InitMetaCache(ctx, rootCoord, mgr)
	err := InitMetaCache(ctx, rootCoord)
	assert.NoError(t, err)
	rootCoord.Error = true

@ -1247,8 +1236,7 @@ func TestMetaCache_GetCollectionFailure(t *testing.T) {
func TestMetaCache_GetNonExistCollection(t *testing.T) {
	ctx := context.Background()
	rootCoord := &MockMixCoordClientInterface{}
	mgr := newShardClientMgr()
	err := InitMetaCache(ctx, rootCoord, mgr)
	err := InitMetaCache(ctx, rootCoord)
	assert.NoError(t, err)

	id, err := globalMetaCache.GetCollectionID(ctx, dbName, "collection3")
@ -1262,8 +1250,7 @@ func TestMetaCache_GetNonExistCollection(t *testing.T) {
func TestMetaCache_GetPartitionID(t *testing.T) {
	ctx := context.Background()
	rootCoord := &MockMixCoordClientInterface{}
	mgr := newShardClientMgr()
	err := InitMetaCache(ctx, rootCoord, mgr)
	err := InitMetaCache(ctx, rootCoord)
	assert.NoError(t, err)

	id, err := globalMetaCache.GetPartitionID(ctx, dbName, "collection1", "par1")
@ -1283,8 +1270,7 @@ func TestMetaCache_GetPartitionID(t *testing.T) {
func TestMetaCache_ConcurrentTest1(t *testing.T) {
	ctx := context.Background()
	rootCoord := &MockMixCoordClientInterface{}
	mgr := newShardClientMgr()
	err := InitMetaCache(ctx, rootCoord, mgr)
	err := InitMetaCache(ctx, rootCoord)
	assert.NoError(t, err)

	var wg sync.WaitGroup
@ -1337,8 +1323,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) {
func TestMetaCache_GetPartitionError(t *testing.T) {
	ctx := context.Background()
	rootCoord := &MockMixCoordClientInterface{}
	mgr := newShardClientMgr()
	err := InitMetaCache(ctx, rootCoord, mgr)
	err := InitMetaCache(ctx, rootCoord)
	assert.NoError(t, err)

	// Test the case where ShowPartitionsResponse is not aligned
@ -1362,133 +1347,23 @@ func TestMetaCache_GetPartitionError(t *testing.T) {
}

func TestMetaCache_GetShard(t *testing.T) {
	var (
		ctx            = context.Background()
		collectionName = "collection1"
		collectionID   = int64(1)
	)

	rootCoord := &MockMixCoordClientInterface{}
	mgr := newShardClientMgr()
	err := InitMetaCache(ctx, rootCoord, mgr)
	require.Nil(t, err)

	t.Run("No collection in meta cache", func(t *testing.T) {
		shards, err := globalMetaCache.GetShard(ctx, true, dbName, "non-exists", 0, "channel-1")
		assert.Error(t, err)
		assert.Empty(t, shards)
	})

	t.Run("without shardLeaders in collection info invalid shardLeaders", func(t *testing.T) {
		shards, err := globalMetaCache.GetShard(ctx, false, dbName, collectionName, collectionID, "channel-1")
		assert.Error(t, err)
		assert.Empty(t, shards)
	})

	t.Run("without shardLeaders in collection info", func(t *testing.T) {
		rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
			return &querypb.GetShardLeadersResponse{
				Status: merr.Success(),
				Shards: []*querypb.ShardLeadersList{
					{
						ChannelName: "channel-1",
						NodeIds:     []int64{1, 2, 3},
						NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
						Serviceable: []bool{true, true, true},
					},
				},
			}, nil
		}

		shards, err := globalMetaCache.GetShard(ctx, true, dbName, collectionName, collectionID, "channel-1")
		assert.NoError(t, err)
		assert.NotEmpty(t, shards)
		assert.Equal(t, 3, len(shards))

		// get from cache
		rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
			return &querypb.GetShardLeadersResponse{
				Status: &commonpb.Status{
					ErrorCode: commonpb.ErrorCode_UnexpectedError,
					Reason:    "not implemented",
				},
			}, nil
		}

		shards, err = globalMetaCache.GetShard(ctx, true, dbName, collectionName, collectionID, "channel-1")

		assert.NoError(t, err)
		assert.NotEmpty(t, shards)
		assert.Equal(t, 3, len(shards))
	})
	t.Skip("GetShard has been moved to ShardClientMgr in shardclient package")
	// Test body removed - functionality moved to shardclient package
}

func TestMetaCache_ClearShards(t *testing.T) {
	var (
		ctx            = context.TODO()
		collectionName = "collection1"
		collectionID   = int64(1)
	)

	qc := &MockMixCoordClientInterface{}
	mgr := newShardClientMgr()
	err := InitMetaCache(ctx, qc, mgr)
	require.Nil(t, err)

	t.Run("Clear with no collection info", func(t *testing.T) {
		globalMetaCache.DeprecateShardCache(dbName, "collection_not_exist")
	})

	t.Run("Clear valid collection empty cache", func(t *testing.T) {
		globalMetaCache.DeprecateShardCache(dbName, collectionName)
	})

	t.Run("Clear valid collection valid cache", func(t *testing.T) {
		qc.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
			return &querypb.GetShardLeadersResponse{
				Status: merr.Success(),
				Shards: []*querypb.ShardLeadersList{
					{
						ChannelName: "channel-1",
						NodeIds:     []int64{1, 2, 3},
						NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
						Serviceable: []bool{true, true, true},
					},
				},
			}, nil
		}

		shards, err := globalMetaCache.GetShard(ctx, true, dbName, collectionName, collectionID, "channel-1")
		require.NoError(t, err)
		require.NotEmpty(t, shards)
		require.Equal(t, 3, len(shards))

		globalMetaCache.DeprecateShardCache(dbName, collectionName)

		qc.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
			return &querypb.GetShardLeadersResponse{
				Status: &commonpb.Status{
					ErrorCode: commonpb.ErrorCode_UnexpectedError,
					Reason:    "not implemented",
				},
			}, nil
		}

		shards, err = globalMetaCache.GetShard(ctx, true, dbName, collectionName, collectionID, "channel-1")
		assert.Error(t, err)
		assert.Empty(t, shards)
	})
	t.Skip("DeprecateShardCache has been moved to ShardClientMgr in shardclient package")
	// Test body removed - functionality moved to shardclient package
}

func TestMetaCache_PolicyInfo(t *testing.T) {
	client := &MockMixCoordClientInterface{}
	mgr := newShardClientMgr()

	t.Run("InitMetaCache", func(t *testing.T) {
		client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
			return nil, errors.New("mock error")
		}
		err := InitMetaCache(context.Background(), client, mgr)
		err := InitMetaCache(context.Background(), client)
		assert.Error(t, err)

		client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
@ -1497,7 +1372,7 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
				PolicyInfos: []string{"policy1", "policy2", "policy3"},
			}, nil
		}
		err = InitMetaCache(context.Background(), client, mgr)
		err = InitMetaCache(context.Background(), client)
		assert.NoError(t, err)
	})

@ -1509,7 +1384,7 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
				UserRoles: []string{funcutil.EncodeUserRoleCache("foo", "role1"), funcutil.EncodeUserRoleCache("foo", "role2"), funcutil.EncodeUserRoleCache("foo2", "role2")},
			}, nil
		}
		err := InitMetaCache(context.Background(), client, mgr)
		err := InitMetaCache(context.Background(), client)
		assert.NoError(t, err)
		policyInfos := privilege.GetPrivilegeCache().GetPrivilegeInfo(context.Background())
		assert.Equal(t, 3, len(policyInfos))
@ -1525,7 +1400,7 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
				UserRoles: []string{funcutil.EncodeUserRoleCache("foo", "role1"), funcutil.EncodeUserRoleCache("foo", "role2"), funcutil.EncodeUserRoleCache("foo2", "role2")},
			}, nil
		}
		err := InitMetaCache(context.Background(), client, mgr)
		err := InitMetaCache(context.Background(), client)
		assert.NoError(t, err)

		err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheGrantPrivilege, OpKey: "policyX"})
@ -1566,7 +1441,7 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
				UserRoles: []string{funcutil.EncodeUserRoleCache("foo", "role1"), funcutil.EncodeUserRoleCache("foo", "role2"), funcutil.EncodeUserRoleCache("foo2", "role2"), funcutil.EncodeUserRoleCache("foo2", "role3")},
			}, nil
		}
		err := InitMetaCache(context.Background(), client, mgr)
		err := InitMetaCache(context.Background(), client)
		assert.NoError(t, err)

		err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheDeleteUser, OpKey: "foo"})
@ -1601,8 +1476,7 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
func TestMetaCache_RemoveCollection(t *testing.T) {
	ctx := context.Background()
	rootCoord := &MockMixCoordClientInterface{}
	shardMgr := newShardClientMgr()
	err := InitMetaCache(ctx, rootCoord, shardMgr)
	err := InitMetaCache(ctx, rootCoord)
	assert.NoError(t, err)

	rootCoord.showLoadCollections = func(ctx context.Context, in *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
@ -1646,48 +1520,14 @@ func TestMetaCache_RemoveCollection(t *testing.T) {
}

func TestGlobalMetaCache_ShuffleShardLeaders(t *testing.T) {
	shards := map[string][]nodeInfo{
		"channel-1": {
			{
				nodeID:  1,
				address: "localhost:9000",
			},
			{
				nodeID:  2,
				address: "localhost:9000",
			},
			{
				nodeID:  3,
				address: "localhost:9000",
			},
		},
	}
	sl := &shardLeaders{
		idx:          uatomic.NewInt64(5),
		shardLeaders: shards,
	}

	reader := sl.GetReader()
	result := reader.Shuffle()
	assert.Len(t, result["channel-1"], 3)
	assert.Equal(t, int64(1), result["channel-1"][0].nodeID)

	reader = sl.GetReader()
	result = reader.Shuffle()
	assert.Len(t, result["channel-1"], 3)
	assert.Equal(t, int64(2), result["channel-1"][0].nodeID)

	reader = sl.GetReader()
	result = reader.Shuffle()
	assert.Len(t, result["channel-1"], 3)
	assert.Equal(t, int64(3), result["channel-1"][0].nodeID)
	t.Skip("shardLeaders and nodeInfo have been moved to shardclient package")
	// Test body removed - functionality moved to shardclient package
}

func TestMetaCache_Database(t *testing.T) {
	ctx := context.Background()
	rootCoord := &MockMixCoordClientInterface{}
	shardMgr := newShardClientMgr()
	err := InitMetaCache(ctx, rootCoord, shardMgr)
	err := InitMetaCache(ctx, rootCoord)
	assert.NoError(t, err)
	assert.Equal(t, globalMetaCache.HasDatabase(ctx, dbName), false)

@ -1703,8 +1543,7 @@ func TestGetDatabaseInfo(t *testing.T) {
	t.Run("success", func(t *testing.T) {
		ctx := context.Background()
		rootCoord := mocks.NewMockMixCoordClient(t)
		shardMgr := newShardClientMgr()
		cache, err := NewMetaCache(rootCoord, shardMgr)
		cache, err := NewMetaCache(rootCoord)
		assert.NoError(t, err)

		rootCoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{
@ -1728,8 +1567,7 @@ func TestGetDatabaseInfo(t *testing.T) {
	t.Run("error", func(t *testing.T) {
		ctx := context.Background()
		rootCoord := mocks.NewMockMixCoordClient(t)
		shardMgr := newShardClientMgr()
		cache, err := NewMetaCache(rootCoord, shardMgr)
		cache, err := NewMetaCache(rootCoord)
		assert.NoError(t, err)

		rootCoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{
@ -1742,7 +1580,6 @@

func TestMetaCache_AllocID(t *testing.T) {
	ctx := context.Background()
	shardMgr := newShardClientMgr()

	t.Run("success", func(t *testing.T) {
		rootCoord := mocks.NewMockMixCoordClient(t)
@ -1756,7 +1593,7 @@ func TestMetaCache_AllocID(t *testing.T) {
			PolicyInfos: []string{"policy1", "policy2", "policy3"},
		}, nil)

		err := InitMetaCache(ctx, rootCoord, shardMgr)
		err := InitMetaCache(ctx, rootCoord)
		assert.NoError(t, err)
		assert.Equal(t, globalMetaCache.HasDatabase(ctx, dbName), false)

@ -1775,7 +1612,7 @@ func TestMetaCache_AllocID(t *testing.T) {
			PolicyInfos: []string{"policy1", "policy2", "policy3"},
		}, nil)

		err := InitMetaCache(ctx, rootCoord, shardMgr)
		err := InitMetaCache(ctx, rootCoord)
		assert.NoError(t, err)
		assert.Equal(t, globalMetaCache.HasDatabase(ctx, dbName), false)

@ -1794,7 +1631,7 @@ func TestMetaCache_AllocID(t *testing.T) {
			PolicyInfos: []string{"policy1", "policy2", "policy3"},
		}, nil)

		err := InitMetaCache(ctx, rootCoord, shardMgr)
		err := InitMetaCache(ctx, rootCoord)
		assert.NoError(t, err)
		assert.Equal(t, globalMetaCache.HasDatabase(ctx, dbName), false)

@ -1805,52 +1642,8 @@
}

func TestMetaCache_InvalidateShardLeaderCache(t *testing.T) {
	paramtable.Init()
	paramtable.Get().Save(Params.ProxyCfg.ShardLeaderCacheInterval.Key, "1")

	ctx := context.Background()
	rootCoord := &MockMixCoordClientInterface{}
	shardMgr := newShardClientMgr()
	err := InitMetaCache(ctx, rootCoord, shardMgr)
	assert.NoError(t, err)

	rootCoord.showLoadCollections = func(ctx context.Context, in *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
		return &querypb.ShowCollectionsResponse{
			Status:              merr.Success(),
			CollectionIDs:       []UniqueID{1},
			InMemoryPercentages: []int64{100},
		}, nil
	}

	called := uatomic.NewInt32(0)

	rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
		called.Inc()
		return &querypb.GetShardLeadersResponse{
			Status: merr.Success(),
			Shards: []*querypb.ShardLeadersList{
				{
					ChannelName: "channel-1",
					NodeIds:     []int64{1, 2, 3},
					NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
					Serviceable: []bool{true, true, true},
				},
			},
		}, nil
	}
	nodeInfos, err := globalMetaCache.GetShard(ctx, true, dbName, "collection1", 1, "channel-1")
	assert.NoError(t, err)
	assert.Len(t, nodeInfos, 3)
	assert.Equal(t, called.Load(), int32(1))

	globalMetaCache.GetShard(ctx, true, dbName, "collection1", 1, "channel-1")
	assert.Equal(t, called.Load(), int32(1))

	globalMetaCache.InvalidateShardLeaderCache([]int64{1})
	nodeInfos, err = globalMetaCache.GetShard(ctx, true, dbName, "collection1", 1, "channel-1")
	assert.NoError(t, err)
	assert.Len(t, nodeInfos, 3)
	assert.Equal(t, called.Load(), int32(2))
	t.Skip("GetShard and InvalidateShardLeaderCache have been moved to ShardClientMgr in shardclient package")
	// Test body removed - functionality moved to shardclient package
}

func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) {
@ -2192,8 +1985,7 @@ func TestMetaCache_Parallel(t *testing.T) {
	rootCoord.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{
		Status: merr.Success(),
	}, nil).Maybe()
	mgr := newShardClientMgr()
	cache, err := NewMetaCache(rootCoord, mgr)
	cache, err := NewMetaCache(rootCoord)
	assert.NoError(t, err)

	cacheVersion := uint64(100)
@ -2242,84 +2034,6 @@ func TestMetaCache_Parallel(t *testing.T) {
}

func TestMetaCache_GetShardLeaderList(t *testing.T) {
	var (
		ctx            = context.Background()
		collectionName = "collection1"
		collectionID   = int64(1)
	)

	rootCoord := &MockMixCoordClientInterface{}
	mgr := newShardClientMgr()
	err := InitMetaCache(ctx, rootCoord, mgr)
	require.Nil(t, err)

	t.Run("No collection in meta cache", func(t *testing.T) {
		channels, err := globalMetaCache.GetShardLeaderList(ctx, dbName, "non-exists", 0, true)
		assert.Error(t, err)
		assert.Empty(t, channels)
	})

	t.Run("Get channel list without cache", func(t *testing.T) {
		rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
			return &querypb.GetShardLeadersResponse{
				Status: merr.Success(),
				Shards: []*querypb.ShardLeadersList{
					{
						ChannelName: "channel-1",
						NodeIds:     []int64{1, 2, 3},
						NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
						Serviceable: []bool{true, true, true},
					},
					{
						ChannelName: "channel-2",
						NodeIds:     []int64{4, 5, 6},
						NodeAddrs:   []string{"localhost:9003", "localhost:9004", "localhost:9005"},
						Serviceable: []bool{true, true, true},
					},
				},
			}, nil
		}

		channels, err := globalMetaCache.GetShardLeaderList(ctx, dbName, collectionName, collectionID, false)
		assert.NoError(t, err)
		assert.Equal(t, 2, len(channels))
		assert.Contains(t, channels, "channel-1")
		assert.Contains(t, channels, "channel-2")
	})

	t.Run("Get channel list with cache", func(t *testing.T) {
		// First call should populate cache
		channels, err := globalMetaCache.GetShardLeaderList(ctx, dbName, collectionName, collectionID, true)
		assert.NoError(t, err)
		assert.Equal(t, 2, len(channels))

		// Mock should return error but cache should be used
		rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
			return &querypb.GetShardLeadersResponse{
				Status: &commonpb.Status{
					ErrorCode: commonpb.ErrorCode_UnexpectedError,
					Reason:    "should not be called when using cache",
				},
			}, nil
		}

		channels, err = globalMetaCache.GetShardLeaderList(ctx, dbName, collectionName, collectionID, true)
		assert.NoError(t, err)
		assert.Equal(t, 2, len(channels))
		assert.Contains(t, channels, "channel-1")
		assert.Contains(t, channels, "channel-2")
	})

	t.Run("Error from coordinator", func(t *testing.T) {
		// Deprecate cache first
		globalMetaCache.DeprecateShardCache(dbName, collectionName)

		rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
			return nil, errors.New("coordinator error")
		}

		channels, err := globalMetaCache.GetShardLeaderList(ctx, dbName, collectionName, collectionID, true)
		assert.Error(t, err)
		assert.Empty(t, channels)
	})
	t.Skip("GetShardLeaderList has been moved to ShardClientMgr in shardclient package")
	// Test body removed - functionality moved to shardclient package
}

@ -55,8 +55,7 @@ func InitEmptyGlobalCache() {
	emptyMock := common.NewEmptyMockT()
	mixcoord := mocks.NewMockMixCoordClient(emptyMock)
	mixcoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("collection not found"))
	mgr := newShardClientMgr()
	globalMetaCache, err = NewMetaCache(mixcoord, mgr)
	globalMetaCache, err = NewMetaCache(mixcoord)
	if err != nil {
		panic(err)
	}

@ -5,10 +5,7 @@ package proxy
|
||||
import (
|
||||
context "context"
|
||||
|
||||
internalpb "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
|
||||
typeutil "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
// MockCache is an autogenerated mock type for the Cache type
|
||||
@ -80,40 +77,6 @@ func (_c *MockCache_AllocID_Call) RunAndReturn(run func(context.Context) (int64,
|
||||
return _c
|
||||
}
|
||||
|
||||
// DeprecateShardCache provides a mock function with given fields: database, collectionName
|
||||
func (_m *MockCache) DeprecateShardCache(database string, collectionName string) {
|
||||
_m.Called(database, collectionName)
|
||||
}
|
||||
|
||||
// MockCache_DeprecateShardCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeprecateShardCache'
|
||||
type MockCache_DeprecateShardCache_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// DeprecateShardCache is a helper method to define mock.On call
|
||||
// - database string
|
||||
// - collectionName string
|
||||
func (_e *MockCache_Expecter) DeprecateShardCache(database interface{}, collectionName interface{}) *MockCache_DeprecateShardCache_Call {
|
||||
return &MockCache_DeprecateShardCache_Call{Call: _e.mock.On("DeprecateShardCache", database, collectionName)}
|
||||
}
|
||||
|
||||
func (_c *MockCache_DeprecateShardCache_Call) Run(run func(database string, collectionName string)) *MockCache_DeprecateShardCache_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string), args[1].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_DeprecateShardCache_Call) Return() *MockCache_DeprecateShardCache_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_DeprecateShardCache_Call) RunAndReturn(run func(string, string)) *MockCache_DeprecateShardCache_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetCollectionID provides a mock function with given fields: ctx, database, collectionName
|
||||
func (_m *MockCache) GetCollectionID(ctx context.Context, database string, collectionName string) (int64, error) {
|
||||
ret := _m.Called(ctx, database, collectionName)
|
||||
@ -351,65 +314,6 @@ func (_c *MockCache_GetCollectionSchema_Call) RunAndReturn(run func(context.Cont
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetCredentialInfo provides a mock function with given fields: ctx, username
|
||||
func (_m *MockCache) GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) {
|
||||
ret := _m.Called(ctx, username)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetCredentialInfo")
|
||||
}
|
||||
|
||||
var r0 *internalpb.CredentialInfo
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) (*internalpb.CredentialInfo, error)); ok {
|
||||
return rf(ctx, username)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string) *internalpb.CredentialInfo); ok {
|
||||
r0 = rf(ctx, username)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(*internalpb.CredentialInfo)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
|
||||
r1 = rf(ctx, username)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockCache_GetCredentialInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCredentialInfo'
|
||||
type MockCache_GetCredentialInfo_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetCredentialInfo is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - username string
|
||||
func (_e *MockCache_Expecter) GetCredentialInfo(ctx interface{}, username interface{}) *MockCache_GetCredentialInfo_Call {
|
||||
return &MockCache_GetCredentialInfo_Call{Call: _e.mock.On("GetCredentialInfo", ctx, username)}
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetCredentialInfo_Call) Run(run func(ctx context.Context, username string)) *MockCache_GetCredentialInfo_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetCredentialInfo_Call) Return(_a0 *internalpb.CredentialInfo, _a1 error) *MockCache_GetCredentialInfo_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetCredentialInfo_Call) RunAndReturn(run func(context.Context, string) (*internalpb.CredentialInfo, error)) *MockCache_GetCredentialInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetDatabaseInfo provides a mock function with given fields: ctx, database
|
||||
func (_m *MockCache) GetDatabaseInfo(ctx context.Context, database string) (*databaseInfo, error) {
|
||||
ret := _m.Called(ctx, database)
|
||||
@ -709,227 +613,6 @@ func (_c *MockCache_GetPartitionsIndex_Call) RunAndReturn(run func(context.Conte
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetPrivilegeInfo provides a mock function with given fields: ctx
|
||||
func (_m *MockCache) GetPrivilegeInfo(ctx context.Context) []string {
|
||||
ret := _m.Called(ctx)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetPrivilegeInfo")
|
||||
}
|
||||
|
||||
var r0 []string
|
||||
if rf, ok := ret.Get(0).(func(context.Context) []string); ok {
|
||||
r0 = rf(ctx)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]string)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockCache_GetPrivilegeInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPrivilegeInfo'
|
||||
type MockCache_GetPrivilegeInfo_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetPrivilegeInfo is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
func (_e *MockCache_Expecter) GetPrivilegeInfo(ctx interface{}) *MockCache_GetPrivilegeInfo_Call {
|
||||
return &MockCache_GetPrivilegeInfo_Call{Call: _e.mock.On("GetPrivilegeInfo", ctx)}
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetPrivilegeInfo_Call) Run(run func(ctx context.Context)) *MockCache_GetPrivilegeInfo_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetPrivilegeInfo_Call) Return(_a0 []string) *MockCache_GetPrivilegeInfo_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetPrivilegeInfo_Call) RunAndReturn(run func(context.Context) []string) *MockCache_GetPrivilegeInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetShard provides a mock function with given fields: ctx, withCache, database, collectionName, collectionID, channel
|
||||
func (_m *MockCache) GetShard(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64, channel string) ([]nodeInfo, error) {
|
||||
ret := _m.Called(ctx, withCache, database, collectionName, collectionID, channel)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetShard")
|
||||
}
|
||||
|
||||
var r0 []nodeInfo
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64, string) ([]nodeInfo, error)); ok {
|
||||
return rf(ctx, withCache, database, collectionName, collectionID, channel)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64, string) []nodeInfo); ok {
|
||||
r0 = rf(ctx, withCache, database, collectionName, collectionID, channel)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]nodeInfo)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, bool, string, string, int64, string) error); ok {
|
||||
r1 = rf(ctx, withCache, database, collectionName, collectionID, channel)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockCache_GetShard_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetShard'
|
||||
type MockCache_GetShard_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetShard is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - withCache bool
|
||||
// - database string
|
||||
// - collectionName string
|
||||
// - collectionID int64
|
||||
// - channel string
|
||||
func (_e *MockCache_Expecter) GetShard(ctx interface{}, withCache interface{}, database interface{}, collectionName interface{}, collectionID interface{}, channel interface{}) *MockCache_GetShard_Call {
|
||||
return &MockCache_GetShard_Call{Call: _e.mock.On("GetShard", ctx, withCache, database, collectionName, collectionID, channel)}
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetShard_Call) Run(run func(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64, channel string)) *MockCache_GetShard_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(bool), args[2].(string), args[3].(string), args[4].(int64), args[5].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetShard_Call) Return(_a0 []nodeInfo, _a1 error) *MockCache_GetShard_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetShard_Call) RunAndReturn(run func(context.Context, bool, string, string, int64, string) ([]nodeInfo, error)) *MockCache_GetShard_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetShardLeaderList provides a mock function with given fields: ctx, database, collectionName, collectionID, withCache
|
||||
func (_m *MockCache) GetShardLeaderList(ctx context.Context, database string, collectionName string, collectionID int64, withCache bool) ([]string, error) {
|
||||
ret := _m.Called(ctx, database, collectionName, collectionID, withCache)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetShardLeaderList")
|
||||
}
|
||||
|
||||
var r0 []string
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, bool) ([]string, error)); ok {
|
||||
return rf(ctx, database, collectionName, collectionID, withCache)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, bool) []string); ok {
|
||||
r0 = rf(ctx, database, collectionName, collectionID, withCache)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]string)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string, int64, bool) error); ok {
|
||||
r1 = rf(ctx, database, collectionName, collectionID, withCache)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockCache_GetShardLeaderList_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetShardLeaderList'
|
||||
type MockCache_GetShardLeaderList_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetShardLeaderList is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - database string
|
||||
// - collectionName string
|
||||
// - collectionID int64
|
||||
// - withCache bool
|
||||
func (_e *MockCache_Expecter) GetShardLeaderList(ctx interface{}, database interface{}, collectionName interface{}, collectionID interface{}, withCache interface{}) *MockCache_GetShardLeaderList_Call {
|
||||
return &MockCache_GetShardLeaderList_Call{Call: _e.mock.On("GetShardLeaderList", ctx, database, collectionName, collectionID, withCache)}
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetShardLeaderList_Call) Run(run func(ctx context.Context, database string, collectionName string, collectionID int64, withCache bool)) *MockCache_GetShardLeaderList_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(int64), args[4].(bool))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetShardLeaderList_Call) Return(_a0 []string, _a1 error) *MockCache_GetShardLeaderList_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetShardLeaderList_Call) RunAndReturn(run func(context.Context, string, string, int64, bool) ([]string, error)) *MockCache_GetShardLeaderList_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetUserRole provides a mock function with given fields: username
|
||||
func (_m *MockCache) GetUserRole(username string) []string {
|
||||
ret := _m.Called(username)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetUserRole")
|
||||
}
|
||||
|
||||
var r0 []string
|
||||
if rf, ok := ret.Get(0).(func(string) []string); ok {
|
||||
r0 = rf(username)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]string)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockCache_GetUserRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUserRole'
|
||||
type MockCache_GetUserRole_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetUserRole is a helper method to define mock.On call
|
||||
// - username string
|
||||
func (_e *MockCache_Expecter) GetUserRole(username interface{}) *MockCache_GetUserRole_Call {
|
||||
return &MockCache_GetUserRole_Call{Call: _e.mock.On("GetUserRole", username)}
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetUserRole_Call) Run(run func(username string)) *MockCache_GetUserRole_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetUserRole_Call) Return(_a0 []string) *MockCache_GetUserRole_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_GetUserRole_Call) RunAndReturn(run func(string) []string) *MockCache_GetUserRole_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// HasDatabase provides a mock function with given fields: ctx, database
|
||||
func (_m *MockCache) HasDatabase(ctx context.Context, database string) bool {
|
||||
ret := _m.Called(ctx, database)
|
||||
@ -977,166 +660,6 @@ func (_c *MockCache_HasDatabase_Call) RunAndReturn(run func(context.Context, str
|
||||
return _c
|
||||
}
|
||||
|
||||
// InitPolicyInfo provides a mock function with given fields: info, userRoles
|
||||
func (_m *MockCache) InitPolicyInfo(info []string, userRoles []string) {
|
||||
_m.Called(info, userRoles)
|
||||
}
|
||||
|
||||
// MockCache_InitPolicyInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InitPolicyInfo'
|
||||
type MockCache_InitPolicyInfo_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// InitPolicyInfo is a helper method to define mock.On call
|
||||
// - info []string
|
||||
// - userRoles []string
|
||||
func (_e *MockCache_Expecter) InitPolicyInfo(info interface{}, userRoles interface{}) *MockCache_InitPolicyInfo_Call {
|
||||
return &MockCache_InitPolicyInfo_Call{Call: _e.mock.On("InitPolicyInfo", info, userRoles)}
|
||||
}
|
||||
|
||||
func (_c *MockCache_InitPolicyInfo_Call) Run(run func(info []string, userRoles []string)) *MockCache_InitPolicyInfo_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]string), args[1].([]string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_InitPolicyInfo_Call) Return() *MockCache_InitPolicyInfo_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_InitPolicyInfo_Call) RunAndReturn(run func([]string, []string)) *MockCache_InitPolicyInfo_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// InvalidateShardLeaderCache provides a mock function with given fields: collections
|
||||
func (_m *MockCache) InvalidateShardLeaderCache(collections []int64) {
|
||||
_m.Called(collections)
|
||||
}
|
||||
|
||||
// MockCache_InvalidateShardLeaderCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateShardLeaderCache'
|
||||
type MockCache_InvalidateShardLeaderCache_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// InvalidateShardLeaderCache is a helper method to define mock.On call
|
||||
// - collections []int64
|
||||
func (_e *MockCache_Expecter) InvalidateShardLeaderCache(collections interface{}) *MockCache_InvalidateShardLeaderCache_Call {
|
||||
return &MockCache_InvalidateShardLeaderCache_Call{Call: _e.mock.On("InvalidateShardLeaderCache", collections)}
|
||||
}
|
||||
|
||||
func (_c *MockCache_InvalidateShardLeaderCache_Call) Run(run func(collections []int64)) *MockCache_InvalidateShardLeaderCache_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]int64))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_InvalidateShardLeaderCache_Call) Return() *MockCache_InvalidateShardLeaderCache_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_InvalidateShardLeaderCache_Call) RunAndReturn(run func([]int64)) *MockCache_InvalidateShardLeaderCache_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ListShardLocation provides a mock function with no fields
|
||||
func (_m *MockCache) ListShardLocation() map[int64]nodeInfo {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListShardLocation")
|
||||
}
|
||||
|
||||
var r0 map[int64]nodeInfo
|
||||
if rf, ok := ret.Get(0).(func() map[int64]nodeInfo); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(map[int64]nodeInfo)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockCache_ListShardLocation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListShardLocation'
|
||||
type MockCache_ListShardLocation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ListShardLocation is a helper method to define mock.On call
|
||||
func (_e *MockCache_Expecter) ListShardLocation() *MockCache_ListShardLocation_Call {
|
||||
return &MockCache_ListShardLocation_Call{Call: _e.mock.On("ListShardLocation")}
|
||||
}
|
||||
|
||||
func (_c *MockCache_ListShardLocation_Call) Run(run func()) *MockCache_ListShardLocation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_ListShardLocation_Call) Return(_a0 map[int64]nodeInfo) *MockCache_ListShardLocation_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_ListShardLocation_Call) RunAndReturn(run func() map[int64]nodeInfo) *MockCache_ListShardLocation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RefreshPolicyInfo provides a mock function with given fields: op
|
||||
func (_m *MockCache) RefreshPolicyInfo(op typeutil.CacheOp) error {
|
||||
ret := _m.Called(op)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for RefreshPolicyInfo")
|
||||
}
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(typeutil.CacheOp) error); ok {
|
||||
r0 = rf(op)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockCache_RefreshPolicyInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RefreshPolicyInfo'
|
||||
type MockCache_RefreshPolicyInfo_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RefreshPolicyInfo is a helper method to define mock.On call
|
||||
// - op typeutil.CacheOp
|
||||
func (_e *MockCache_Expecter) RefreshPolicyInfo(op interface{}) *MockCache_RefreshPolicyInfo_Call {
|
||||
return &MockCache_RefreshPolicyInfo_Call{Call: _e.mock.On("RefreshPolicyInfo", op)}
|
||||
}
|
||||
|
||||
func (_c *MockCache_RefreshPolicyInfo_Call) Run(run func(op typeutil.CacheOp)) *MockCache_RefreshPolicyInfo_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(typeutil.CacheOp))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_RefreshPolicyInfo_Call) Return(_a0 error) *MockCache_RefreshPolicyInfo_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_RefreshPolicyInfo_Call) RunAndReturn(run func(typeutil.CacheOp) error) *MockCache_RefreshPolicyInfo_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RemoveCollection provides a mock function with given fields: ctx, database, collectionName
|
||||
func (_m *MockCache) RemoveCollection(ctx context.Context, database string, collectionName string) {
|
||||
_m.Called(ctx, database, collectionName)
|
||||
@ -1223,39 +746,6 @@ func (_c *MockCache_RemoveCollectionsByID_Call) RunAndReturn(run func(context.Co
|
||||
return _c
|
||||
}
|
||||
|
||||
// RemoveCredential provides a mock function with given fields: username
|
||||
func (_m *MockCache) RemoveCredential(username string) {
|
||||
_m.Called(username)
|
||||
}
|
||||
|
||||
// MockCache_RemoveCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveCredential'
|
||||
type MockCache_RemoveCredential_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RemoveCredential is a helper method to define mock.On call
|
||||
// - username string
|
||||
func (_e *MockCache_Expecter) RemoveCredential(username interface{}) *MockCache_RemoveCredential_Call {
|
||||
return &MockCache_RemoveCredential_Call{Call: _e.mock.On("RemoveCredential", username)}
|
||||
}
|
||||
|
||||
func (_c *MockCache_RemoveCredential_Call) Run(run func(username string)) *MockCache_RemoveCredential_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_RemoveCredential_Call) Return() *MockCache_RemoveCredential_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_RemoveCredential_Call) RunAndReturn(run func(string)) *MockCache_RemoveCredential_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RemoveDatabase provides a mock function with given fields: ctx, database
|
||||
func (_m *MockCache) RemoveDatabase(ctx context.Context, database string) {
|
||||
_m.Called(ctx, database)
|
||||
@ -1290,39 +780,6 @@ func (_c *MockCache_RemoveDatabase_Call) RunAndReturn(run func(context.Context,
|
||||
return _c
|
||||
}
|
||||
|
||||
// UpdateCredential provides a mock function with given fields: credInfo
|
||||
func (_m *MockCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
|
||||
_m.Called(credInfo)
|
||||
}
|
||||
|
||||
// MockCache_UpdateCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCredential'
|
||||
type MockCache_UpdateCredential_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// UpdateCredential is a helper method to define mock.On call
|
||||
// - credInfo *internalpb.CredentialInfo
|
||||
func (_e *MockCache_Expecter) UpdateCredential(credInfo interface{}) *MockCache_UpdateCredential_Call {
|
||||
return &MockCache_UpdateCredential_Call{Call: _e.mock.On("UpdateCredential", credInfo)}
|
||||
}
|
||||
|
||||
func (_c *MockCache_UpdateCredential_Call) Run(run func(credInfo *internalpb.CredentialInfo)) *MockCache_UpdateCredential_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(*internalpb.CredentialInfo))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_UpdateCredential_Call) Return() *MockCache_UpdateCredential_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockCache_UpdateCredential_Call) RunAndReturn(run func(*internalpb.CredentialInfo)) *MockCache_UpdateCredential_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewMockCache creates a new instance of MockCache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockCache(t interface {
|
||||
|
||||
@ -1,193 +0,0 @@
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
types "github.com/milvus-io/milvus/internal/types"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockShardClientManager is an autogenerated mock type for the shardClientMgr type
|
||||
type MockShardClientManager struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type MockShardClientManager_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *MockShardClientManager) EXPECT() *MockShardClientManager_Expecter {
|
||||
return &MockShardClientManager_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Close provides a mock function with no fields
|
||||
func (_m *MockShardClientManager) Close() {
|
||||
_m.Called()
|
||||
}
|
||||
|
||||
// MockShardClientManager_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
|
||||
type MockShardClientManager_Close_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Close is a helper method to define mock.On call
|
||||
func (_e *MockShardClientManager_Expecter) Close() *MockShardClientManager_Close_Call {
|
||||
return &MockShardClientManager_Close_Call{Call: _e.mock.On("Close")}
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_Close_Call) Run(run func()) *MockShardClientManager_Close_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_Close_Call) Return() *MockShardClientManager_Close_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_Close_Call) RunAndReturn(run func()) *MockShardClientManager_Close_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetClient provides a mock function with given fields: ctx, nodeInfo1
|
||||
func (_m *MockShardClientManager) GetClient(ctx context.Context, nodeInfo1 nodeInfo) (types.QueryNodeClient, error) {
|
||||
ret := _m.Called(ctx, nodeInfo1)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetClient")
|
||||
}
|
||||
|
||||
var r0 types.QueryNodeClient
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, nodeInfo) (types.QueryNodeClient, error)); ok {
|
||||
return rf(ctx, nodeInfo1)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, nodeInfo) types.QueryNodeClient); ok {
|
||||
r0 = rf(ctx, nodeInfo1)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(types.QueryNodeClient)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, nodeInfo) error); ok {
|
||||
r1 = rf(ctx, nodeInfo1)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockShardClientManager_GetClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetClient'
|
||||
type MockShardClientManager_GetClient_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetClient is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - nodeInfo1 nodeInfo
|
||||
func (_e *MockShardClientManager_Expecter) GetClient(ctx interface{}, nodeInfo1 interface{}) *MockShardClientManager_GetClient_Call {
|
||||
return &MockShardClientManager_GetClient_Call{Call: _e.mock.On("GetClient", ctx, nodeInfo1)}
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_GetClient_Call) Run(run func(ctx context.Context, nodeInfo1 nodeInfo)) *MockShardClientManager_GetClient_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(nodeInfo))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_GetClient_Call) Return(_a0 types.QueryNodeClient, _a1 error) *MockShardClientManager_GetClient_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_GetClient_Call) RunAndReturn(run func(context.Context, nodeInfo) (types.QueryNodeClient, error)) *MockShardClientManager_GetClient_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetClientCreatorFunc provides a mock function with given fields: creator
|
||||
func (_m *MockShardClientManager) SetClientCreatorFunc(creator queryNodeCreatorFunc) {
|
||||
_m.Called(creator)
|
||||
}
|
||||
|
||||
// MockShardClientManager_SetClientCreatorFunc_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetClientCreatorFunc'
|
||||
type MockShardClientManager_SetClientCreatorFunc_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SetClientCreatorFunc is a helper method to define mock.On call
|
||||
// - creator queryNodeCreatorFunc
|
||||
func (_e *MockShardClientManager_Expecter) SetClientCreatorFunc(creator interface{}) *MockShardClientManager_SetClientCreatorFunc_Call {
|
||||
return &MockShardClientManager_SetClientCreatorFunc_Call{Call: _e.mock.On("SetClientCreatorFunc", creator)}
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_SetClientCreatorFunc_Call) Run(run func(creator queryNodeCreatorFunc)) *MockShardClientManager_SetClientCreatorFunc_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(queryNodeCreatorFunc))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_SetClientCreatorFunc_Call) Return() *MockShardClientManager_SetClientCreatorFunc_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_SetClientCreatorFunc_Call) RunAndReturn(run func(queryNodeCreatorFunc)) *MockShardClientManager_SetClientCreatorFunc_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Start provides a mock function with no fields
|
||||
func (_m *MockShardClientManager) Start() {
|
||||
_m.Called()
|
||||
}
|
||||
|
||||
// MockShardClientManager_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start'
|
||||
type MockShardClientManager_Start_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Start is a helper method to define mock.On call
|
||||
func (_e *MockShardClientManager_Expecter) Start() *MockShardClientManager_Start_Call {
|
||||
return &MockShardClientManager_Start_Call{Call: _e.mock.On("Start")}
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_Start_Call) Run(run func()) *MockShardClientManager_Start_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_Start_Call) Return() *MockShardClientManager_Start_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_Start_Call) RunAndReturn(run func()) *MockShardClientManager_Start_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewMockShardClientManager creates a new instance of MockShardClientManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockShardClientManager(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *MockShardClientManager {
|
||||
mock := &MockShardClientManager{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
|
||||
@ -119,7 +119,7 @@ func TestRepackInsertDataWithPartitionKey(t *testing.T) {
|
||||
|
||||
mix := NewMixCoordMock()
|
||||
|
||||
err := InitMetaCache(ctx, mix, nil)
|
||||
err := InitMetaCache(ctx, mix)
|
||||
assert.NoError(t, err)
|
||||
|
||||
idAllocator, err := allocator.NewIDAllocator(ctx, mix, paramtable.GetNodeID())
|
||||
|
||||
@ -49,7 +49,6 @@ func TestPrivilegeInterceptor(t *testing.T) {
|
||||
|
||||
ctx = GetContext(context.Background(), "alice:123456")
|
||||
client := &MockMixCoordClientInterface{}
|
||||
mgr := newShardClientMgr()
|
||||
|
||||
client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
|
||||
return &internalpb.ListPolicyResponse{
|
||||
@ -80,7 +79,7 @@ func TestPrivilegeInterceptor(t *testing.T) {
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = InitMetaCache(ctx, client, mgr)
|
||||
err = InitMetaCache(ctx, client)
|
||||
assert.NoError(t, err)
|
||||
_, err = PrivilegeInterceptor(ctx, &milvuspb.HasCollectionRequest{
|
||||
DbName: "db_test",
|
||||
@ -218,7 +217,6 @@ func TestResourceGroupPrivilege(t *testing.T) {
|
||||
|
||||
ctx = GetContext(context.Background(), "fooo:123456")
|
||||
client := &MockMixCoordClientInterface{}
|
||||
mgr := newShardClientMgr()
|
||||
|
||||
client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
|
||||
return &internalpb.ListPolicyResponse{
|
||||
@ -236,7 +234,7 @@ func TestResourceGroupPrivilege(t *testing.T) {
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
InitMetaCache(ctx, client, mgr)
|
||||
InitMetaCache(ctx, client)
|
||||
|
||||
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.CreateResourceGroupRequest{
|
||||
ResourceGroup: "rg",
|
||||
@ -274,7 +272,6 @@ func TestPrivilegeGroup(t *testing.T) {
|
||||
var err error
|
||||
ctx = GetContext(context.Background(), "fooo:123456")
|
||||
client := &MockMixCoordClientInterface{}
|
||||
mgr := newShardClientMgr()
|
||||
|
||||
client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
|
||||
return &internalpb.ListPolicyResponse{
|
||||
@ -287,7 +284,7 @@ func TestPrivilegeGroup(t *testing.T) {
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
InitMetaCache(ctx, client, mgr)
|
||||
InitMetaCache(ctx, client)
|
||||
defer privilege.CleanPrivilegeCache()
|
||||
|
||||
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
|
||||
@ -331,7 +328,6 @@ func TestPrivilegeGroup(t *testing.T) {
|
||||
var err error
|
||||
ctx = GetContext(context.Background(), "fooo:123456")
|
||||
client := &MockMixCoordClientInterface{}
|
||||
mgr := newShardClientMgr()
|
||||
|
||||
client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
|
||||
return &internalpb.ListPolicyResponse{
|
||||
@ -344,7 +340,7 @@ func TestPrivilegeGroup(t *testing.T) {
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
InitMetaCache(ctx, client, mgr)
|
||||
InitMetaCache(ctx, client)
|
||||
defer privilege.CleanPrivilegeCache()
|
||||
|
||||
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
|
||||
@ -388,7 +384,6 @@ func TestPrivilegeGroup(t *testing.T) {
|
||||
var err error
|
||||
ctx = GetContext(context.Background(), "fooo:123456")
|
||||
client := &MockMixCoordClientInterface{}
|
||||
mgr := newShardClientMgr()
|
||||
|
||||
client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
|
||||
return &internalpb.ListPolicyResponse{
|
||||
@ -401,7 +396,7 @@ func TestPrivilegeGroup(t *testing.T) {
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
InitMetaCache(ctx, client, mgr)
|
||||
InitMetaCache(ctx, client)
|
||||
defer privilege.CleanPrivilegeCache()
|
||||
|
||||
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
|
||||
@ -491,7 +486,6 @@ func TestPrivilegeGroup(t *testing.T) {
|
||||
var err error
|
||||
ctx = GetContext(context.Background(), "fooo:123456")
|
||||
client := &MockMixCoordClientInterface{}
|
||||
mgr := newShardClientMgr()
|
||||
|
||||
client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
|
||||
return &internalpb.ListPolicyResponse{
|
||||
@ -504,7 +498,7 @@ func TestPrivilegeGroup(t *testing.T) {
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
InitMetaCache(ctx, client, mgr)
|
||||
InitMetaCache(ctx, client)
|
||||
defer privilege.CleanPrivilegeCache()
|
||||
|
||||
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
|
||||
@ -558,7 +552,6 @@ func TestPrivilegeGroup(t *testing.T) {
|
||||
var err error
|
||||
ctx = GetContext(context.Background(), "fooo:123456")
|
||||
client := &MockMixCoordClientInterface{}
|
||||
mgr := newShardClientMgr()
|
||||
|
||||
client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
|
||||
return &internalpb.ListPolicyResponse{
|
||||
@ -571,7 +564,7 @@ func TestPrivilegeGroup(t *testing.T) {
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
InitMetaCache(ctx, client, mgr)
|
||||
InitMetaCache(ctx, client)
|
||||
defer privilege.CleanPrivilegeCache()
|
||||
|
||||
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{})
|
||||
@ -599,7 +592,6 @@ func TestBuiltinPrivilegeGroup(t *testing.T) {
|
||||
var err error
|
||||
ctx := GetContext(context.Background(), "fooo:123456")
|
||||
client := &MockMixCoordClientInterface{}
|
||||
mgr := newShardClientMgr()
|
||||
|
||||
policies := []string{}
|
||||
for _, priv := range Params.RbacConfig.GetDefaultPrivilegeGroup("ClusterReadOnly").Privileges {
|
||||
@ -615,7 +607,7 @@ func TestBuiltinPrivilegeGroup(t *testing.T) {
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
InitMetaCache(ctx, client, mgr)
|
||||
InitMetaCache(ctx, client)
|
||||
defer privilege.CleanPrivilegeCache()
|
||||
|
||||
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.SelectUserRequest{})
|
||||
|
||||
@ -34,6 +34,7 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/proxy/connection"
|
||||
"github.com/milvus-io/milvus/internal/proxy/shardclient"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/hookutil"
|
||||
@ -96,7 +97,7 @@ type Proxy struct {
|
||||
metricsCacheManager *metricsinfo.MetricsCacheManager
|
||||
|
||||
session *sessionutil.Session
|
||||
shardMgr shardClientMgr
|
||||
shardMgr shardclient.ShardClientMgr
|
||||
|
||||
searchResultCh chan *internalpb.SearchResults
|
||||
|
||||
@ -105,7 +106,7 @@ type Proxy struct {
|
||||
closeCallbacks []func()
|
||||
|
||||
// for load balance in replicas
|
||||
lbPolicy LBPolicy
|
||||
lbPolicy shardclient.LBPolicy
|
||||
|
||||
// resource manager
|
||||
resourceManager resource.Manager
|
||||
@ -124,17 +125,14 @@ func NewProxy(ctx context.Context, _ dependency.Factory) (*Proxy, error) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
ctx1, cancel := context.WithCancel(ctx)
|
||||
n := 1024 // better to be configurable
|
||||
mgr := newShardClientMgr()
|
||||
lbPolicy := NewLBPolicyImpl(mgr)
|
||||
lbPolicy.Start(ctx)
|
||||
resourceManager := resource.NewManager(10*time.Second, 20*time.Second, make(map[string]time.Duration))
|
||||
node := &Proxy{
|
||||
ctx: ctx1,
|
||||
cancel: cancel,
|
||||
searchResultCh: make(chan *internalpb.SearchResults, n),
|
||||
shardMgr: mgr,
|
||||
simpleLimiter: NewSimpleLimiter(Params.QuotaConfig.AllocWaitInterval.GetAsDuration(time.Millisecond), Params.QuotaConfig.AllocRetryTimes.GetAsUint()),
|
||||
lbPolicy: lbPolicy,
|
||||
ctx: ctx1,
|
||||
cancel: cancel,
|
||||
searchResultCh: make(chan *internalpb.SearchResults, n),
|
||||
// shardMgr: mgr,
|
||||
simpleLimiter: NewSimpleLimiter(Params.QuotaConfig.AllocWaitInterval.GetAsDuration(time.Millisecond), Params.QuotaConfig.AllocRetryTimes.GetAsUint()),
|
||||
// lbPolicy: lbPolicy,
|
||||
resourceManager: resourceManager,
|
||||
slowQueries: expirable.NewLRU[Timestamp, *metricsinfo.SlowQuery](20, nil, time.Minute*15),
|
||||
}
|
||||
@ -247,12 +245,15 @@ func (node *Proxy) Init() error {
|
||||
node.metricsCacheManager = metricsinfo.NewMetricsCacheManager()
|
||||
log.Debug("create metrics cache manager done", zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
if err := InitMetaCache(node.ctx, node.mixCoord, node.shardMgr); err != nil {
|
||||
if err := InitMetaCache(node.ctx, node.mixCoord); err != nil {
|
||||
log.Warn("failed to init meta cache", zap.String("role", typeutil.ProxyRole), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
log.Debug("init meta cache done", zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
node.shardMgr = shardclient.NewShardClientMgr(node.mixCoord)
|
||||
node.lbPolicy = shardclient.NewLBPolicyImpl(node.shardMgr)
|
||||
|
||||
node.enableMaterializedView = Params.CommonCfg.EnableMaterializedView.GetAsBool()
|
||||
|
||||
// Enable internal rand pool for UUIDv4 generation
|
||||
@ -273,6 +274,8 @@ func (node *Proxy) Start() error {
|
||||
node.shardMgr.Start()
|
||||
log.Debug("start shard client manager done", zap.String("role", typeutil.ProxyRole))
|
||||
|
||||
node.lbPolicy.Start(node.ctx)
|
||||
|
||||
if err := node.sched.Start(); err != nil {
|
||||
log.Warn("failed to start task scheduler", zap.String("role", typeutil.ProxyRole), zap.Error(err))
|
||||
return err
|
||||
|
||||
@ -52,6 +52,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/json"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/proxy/privilege"
|
||||
"github.com/milvus-io/milvus/internal/proxy/shardclient"
|
||||
"github.com/milvus-io/milvus/internal/util/componentutil"
|
||||
"github.com/milvus-io/milvus/internal/util/dependency"
|
||||
"github.com/milvus-io/milvus/internal/util/sessionutil"
|
||||
@ -74,6 +75,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
_ "github.com/milvus-io/milvus/pkg/v2/util/symbolizer" // support symbolizer and crash dump
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/testutils"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
@ -1037,7 +1039,11 @@ func TestProxy(t *testing.T) {
|
||||
proxy.SetMixCoordClient(rootCoordClient)
|
||||
log.Info("Proxy set mix coordinator client")
|
||||
|
||||
proxy.SetQueryNodeCreator(defaultQueryNodeClientCreator)
|
||||
mockShardMgr := shardclient.NewMockShardClientManager(t)
|
||||
mockShardMgr.EXPECT().SetClientCreatorFunc(mock.Anything).Return().Maybe()
|
||||
proxy.shardMgr = mockShardMgr
|
||||
|
||||
proxy.SetQueryNodeCreator(shardclient.DefaultQueryNodeClientCreator)
|
||||
log.Info("Proxy set query coordinator client")
|
||||
|
||||
proxy.UpdateStateCode(commonpb.StateCode_Initializing)
|
||||
|
||||
@ -404,13 +404,14 @@ func (op *requeryOperator) requery(ctx context.Context, span trace.Span, ids *sc
|
||||
PartitionIDs: op.partitionIDs, // use search partitionIDs
|
||||
ConsistencyLevel: op.consistencyLevel,
|
||||
},
|
||||
request: queryReq,
|
||||
plan: plan,
|
||||
mixCoord: op.node.(*Proxy).mixCoord,
|
||||
lb: op.node.(*Proxy).lbPolicy,
|
||||
channelsMvcc: channelsMvcc,
|
||||
fastSkip: true,
|
||||
reQuery: true,
|
||||
request: queryReq,
|
||||
plan: plan,
|
||||
mixCoord: op.node.(*Proxy).mixCoord,
|
||||
lb: op.node.(*Proxy).lbPolicy,
|
||||
shardclientMgr: op.node.(*Proxy).shardMgr,
|
||||
channelsMvcc: channelsMvcc,
|
||||
fastSkip: true,
|
||||
reQuery: true,
|
||||
}
|
||||
queryResult, storageCost, err := op.node.(*Proxy).query(op.traceCtx, qt, span)
|
||||
if err != nil {
|
||||
|
||||
@ -25,7 +25,7 @@ func TestNewInterceptor(t *testing.T) {
|
||||
mixCoord := mocks.NewMockMixCoordClient(t)
|
||||
mixCoord.On("DescribeCollection", mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe()
|
||||
var err error
|
||||
globalMetaCache, err = NewMetaCache(mixCoord, nil)
|
||||
globalMetaCache, err = NewMetaCache(mixCoord)
|
||||
assert.NoError(t, err)
|
||||
interceptor, err := NewInterceptor[*milvuspb.DescribeCollectionRequest, *milvuspb.DescribeCollectionResponse](node, "DescribeCollection")
|
||||
assert.NoError(t, err)
|
||||
|
||||
@ -1,160 +0,0 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
func TestShardClientMgr(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
nodeInfo := nodeInfo{
|
||||
nodeID: 1,
|
||||
}
|
||||
|
||||
qn := mocks.NewMockQueryNodeClient(t)
|
||||
qn.EXPECT().Close().Return(nil)
|
||||
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
|
||||
return qn, nil
|
||||
}
|
||||
|
||||
mgr := newShardClientMgr()
|
||||
mgr.SetClientCreatorFunc(creator)
|
||||
_, err := mgr.GetClient(ctx, nodeInfo)
|
||||
assert.Nil(t, err)
|
||||
|
||||
mgr.Close()
|
||||
assert.Equal(t, mgr.clients.Len(), 0)
|
||||
}
|
||||
|
||||
func TestShardClient(t *testing.T) {
|
||||
nodeInfo := nodeInfo{
|
||||
nodeID: 1,
|
||||
}
|
||||
|
||||
qn := mocks.NewMockQueryNodeClient(t)
|
||||
qn.EXPECT().Close().Return(nil).Maybe()
|
||||
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
|
||||
return qn, nil
|
||||
}
|
||||
shardClient := newShardClient(nodeInfo, creator, 3*time.Second)
|
||||
assert.Equal(t, len(shardClient.clients), 0)
|
||||
assert.Equal(t, false, shardClient.initialized.Load())
|
||||
assert.Equal(t, false, shardClient.isClosed)
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := shardClient.getClient(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, len(shardClient.clients), paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt())
|
||||
|
||||
// test close
|
||||
closed := shardClient.Close(false)
|
||||
assert.False(t, closed)
|
||||
closed = shardClient.Close(true)
|
||||
assert.True(t, closed)
|
||||
}
|
||||
|
||||
func TestPurgeClient(t *testing.T) {
|
||||
node := nodeInfo{
|
||||
nodeID: 1,
|
||||
}
|
||||
|
||||
returnEmptyResult := atomic.NewBool(false)
|
||||
|
||||
cache := NewMockCache(t)
|
||||
cache.EXPECT().ListShardLocation().RunAndReturn(func() map[int64]nodeInfo {
|
||||
if returnEmptyResult.Load() {
|
||||
return map[int64]nodeInfo{}
|
||||
}
|
||||
return map[int64]nodeInfo{
|
||||
1: node,
|
||||
}
|
||||
})
|
||||
globalMetaCache = cache
|
||||
|
||||
qn := mocks.NewMockQueryNodeClient(t)
|
||||
qn.EXPECT().Close().Return(nil).Maybe()
|
||||
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
|
||||
return qn, nil
|
||||
}
|
||||
|
||||
s := &shardClientMgrImpl{
|
||||
clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](),
|
||||
clientCreator: creator,
|
||||
closeCh: make(chan struct{}),
|
||||
purgeInterval: 1 * time.Second,
|
||||
expiredDuration: 3 * time.Second,
|
||||
}
|
||||
|
||||
go s.PurgeClient()
|
||||
defer s.Close()
|
||||
_, err := s.GetClient(context.Background(), node)
|
||||
assert.Nil(t, err)
|
||||
qnClient, ok := s.clients.Get(1)
|
||||
assert.True(t, ok)
|
||||
assert.True(t, qnClient.lastActiveTs.Load() > 0)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
// expected client should not been purged before expiredDuration
|
||||
assert.Equal(t, s.clients.Len(), 1)
|
||||
assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() >= 2*time.Second.Nanoseconds())
|
||||
|
||||
_, err = s.GetClient(context.Background(), node)
|
||||
assert.Nil(t, err)
|
||||
time.Sleep(2 * time.Second)
|
||||
// GetClient should refresh lastActiveTs, expected client should not be purged
|
||||
assert.Equal(t, s.clients.Len(), 1)
|
||||
assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() < 3*time.Second.Nanoseconds())
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
// client reach the expiredDuration, expected client should not be purged
|
||||
assert.Equal(t, s.clients.Len(), 1)
|
||||
assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() > 3*time.Second.Nanoseconds())
|
||||
|
||||
returnEmptyResult.Store(true)
|
||||
time.Sleep(2 * time.Second)
|
||||
// remove client from shard location, expected client should be purged
|
||||
assert.Equal(t, s.clients.Len(), 0)
|
||||
}
|
||||
|
||||
func BenchmarkShardClientMgr(b *testing.B) {
|
||||
node := nodeInfo{
|
||||
nodeID: 1,
|
||||
}
|
||||
cache := NewMockCache(b)
|
||||
cache.EXPECT().ListShardLocation().Return(map[int64]nodeInfo{
|
||||
1: node,
|
||||
}).Maybe()
|
||||
globalMetaCache = cache
|
||||
qn := mocks.NewMockQueryNodeClient(b)
|
||||
qn.EXPECT().Close().Return(nil).Maybe()
|
||||
|
||||
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
|
||||
return qn, nil
|
||||
}
|
||||
s := &shardClientMgrImpl{
|
||||
clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](),
|
||||
clientCreator: creator,
|
||||
closeCh: make(chan struct{}),
|
||||
purgeInterval: 1 * time.Second,
|
||||
expiredDuration: 10 * time.Second,
|
||||
}
|
||||
go s.PurgeClient()
|
||||
defer s.Close()
|
||||
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_, err := s.GetClient(context.Background(), node)
|
||||
assert.Nil(b, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
8
internal/proxy/shardclient/OWNERS
Normal file
@ -0,0 +1,8 @@
|
||||
# order by contributions
|
||||
reviewers:
|
||||
- weiliu1031
|
||||
- congqixia
|
||||
- czs007
|
||||
|
||||
approvers:
|
||||
- maintainers
|
||||
351
internal/proxy/shardclient/README.md
Normal file
@ -0,0 +1,351 @@
|
||||
# ShardClient Package
|
||||
|
||||
The `shardclient` package provides client-side connection management and load balancing for communicating with QueryNode shards in the Milvus distributed architecture. It manages QueryNode client connections, caches shard leader information, and implements intelligent request routing strategies.
|
||||
|
||||
## Overview
|
||||
|
||||
In Milvus, collections are divided into shards (channels), and each shard has multiple replicas distributed across different QueryNodes for high availability and load balancing. The `shardclient` package is responsible for:
|
||||
|
||||
1. **Connection Management**: Maintaining a pool of gRPC connections to QueryNodes with automatic lifecycle management
|
||||
2. **Shard Leader Cache**: Caching the mapping of shards to their leader QueryNodes to reduce coordination overhead
|
||||
3. **Load Balancing**: Distributing requests across available QueryNode replicas using configurable policies
|
||||
4. **Fault Tolerance**: Automatic retry and failover when QueryNodes become unavailable
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌──────────────────────────────────────────────────────────────┐
|
||||
│ Proxy Layer │
|
||||
│ │
|
||||
│ ┌─────────────────────────────────────────────────────┐ │
|
||||
│ │ ShardClientMgr │ │
|
||||
│ │ • Shard leader cache (database → collection → shards) │
|
||||
│ │ • QueryNode client pool management │
|
||||
│ │ • Client lifecycle (init, purge, close) │
|
||||
│ └───────────────────────┬──────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌───────────────────────▼──────────────────────────────┐ │
|
||||
│ │ LBPolicy │ │
|
||||
│ │ • Execute workload on collection/channels │ │
|
||||
│ │ • Retry logic with replica failover │ │
|
||||
│ │ • Node selection via balancer │ │
|
||||
│ └───────────────────────┬──────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌────────────────┴────────────────┐ │
|
||||
│ │ │ │
|
||||
│ ┌──────▼────────┐ ┌─────────▼──────────┐ │
|
||||
│ │ RoundRobin │ │ LookAsideBalancer │ │
|
||||
│ │ Balancer │ │ • Cost-based │ │
|
||||
│ │ │ │ • Health check │ │
|
||||
│ └───────────────┘ └────────────────────┘ │
|
||||
│ │ │
|
||||
│ ┌───────────────────────▼──────────────────────────────┐ │
|
||||
│ │ shardClient (per QueryNode) │ │
|
||||
│ │ • Connection pool (configurable size) │ │
|
||||
│ │ • Round-robin client selection │ │
|
||||
│ │ • Lazy initialization and expiration │ │
|
||||
│ └──────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────┬────────────────────────────────────────┘
|
||||
│ gRPC
|
||||
┌───────────────┴───────────────┐
|
||||
│ │
|
||||
┌─────▼─────┐ ┌──────▼──────┐
|
||||
│ QueryNode │ │ QueryNode │
|
||||
│ (1) │ │ (2) │
|
||||
└───────────┘ └─────────────┘
|
||||
```
|
||||
|
||||
## Core Components
|
||||
|
||||
### 1. ShardClientMgr
|
||||
|
||||
The central manager for QueryNode client connections and shard leader information.
|
||||
|
||||
**File**: `manager.go`
|
||||
|
||||
**Key Responsibilities**:
|
||||
- Cache shard leader mappings from QueryCoord (`database → collectionName → channel → []NodeInfo`)
|
||||
- Manage `shardClient` instances for each QueryNode
|
||||
- Automatically purge expired clients (default: 60 minutes of inactivity)
|
||||
- Invalidate cache when shard leaders change
|
||||
|
||||
**Interface**:
|
||||
```go
|
||||
type ShardClientMgr interface {
|
||||
GetShard(ctx context.Context, withCache bool, database, collectionName string,
|
||||
collectionID int64, channel string) ([]NodeInfo, error)
|
||||
GetShardLeaderList(ctx context.Context, database, collectionName string,
|
||||
collectionID int64, withCache bool) ([]string, error)
|
||||
DeprecateShardCache(database, collectionName string)
|
||||
InvalidateShardLeaderCache(collections []int64)
|
||||
GetClient(ctx context.Context, nodeInfo NodeInfo) (types.QueryNodeClient, error)
|
||||
Start()
|
||||
Close()
|
||||
}
|
||||
```
|
||||
|
||||
**Configuration**:
|
||||
- `purgeInterval`: Interval for checking expired clients (default: 600s)
|
||||
- `expiredDuration`: Time after which inactive clients are purged (default: 60min)
|
||||
|
||||
### 2. shardClient
|
||||
|
||||
Manages a connection pool to a single QueryNode.
|
||||
|
||||
**File**: `shard_client.go`
|
||||
|
||||
**Features**:
|
||||
- **Lazy initialization**: Connections are created on first use
|
||||
- **Connection pooling**: Configurable pool size (`ProxyCfg.QueryNodePoolingSize`, default: 1)
|
||||
- **Round-robin selection**: Distributes requests across pool connections
|
||||
- **Expiration tracking**: Tracks last active time for automatic cleanup
|
||||
- **Thread-safe**: Safe for concurrent access
|
||||
|
||||
**Lifecycle**:
|
||||
1. Created when first request needs a QueryNode
|
||||
2. Initializes connection pool on first `getClient()` call
|
||||
3. Tracks `lastActiveTs` on each use
|
||||
4. Closed by manager if expired or during shutdown
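
The sketch below illustrates this lifecycle. It is a minimal, illustrative stand-in rather than the actual `shardClient` code: `queryNode` substitutes for `types.QueryNodeClient`, and the real implementation additionally handles pool-size configuration, locking, and close/error paths.

```go
import (
	"errors"
	"sync"
	"time"

	"go.uber.org/atomic"
)

// queryNode stands in for types.QueryNodeClient in this sketch.
type queryNode interface{ Close() error }

// pooledClient mimics the lazy-init, round-robin pool described above.
type pooledClient struct {
	create       func() (queryNode, error)
	poolSize     int
	once         sync.Once
	clients      []queryNode
	idx          atomic.Int64
	lastActiveTs atomic.Int64
}

func (p *pooledClient) get() (queryNode, error) {
	p.once.Do(func() { // pool is built on first use
		for i := 0; i < p.poolSize; i++ {
			c, err := p.create()
			if err != nil {
				return
			}
			p.clients = append(p.clients, c)
		}
	})
	if len(p.clients) == 0 {
		return nil, errors.New("client pool not initialized")
	}
	p.lastActiveTs.Store(time.Now().UnixNano()) // consulted by the purge loop
	next := p.idx.Inc() % int64(len(p.clients)) // round-robin selection
	return p.clients[next], nil
}
```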
|
||||
|
||||
### 3. LBPolicy
|
||||
|
||||
Executes workloads on collections/channels with retry and failover logic.
|
||||
|
||||
**File**: `lb_policy.go`
|
||||
|
||||
**Key Methods**:
|
||||
|
||||
- **`Execute(ctx, CollectionWorkLoad)`**: Execute workload in parallel across all shards
|
||||
- **`ExecuteOneChannel(ctx, CollectionWorkLoad)`**: Execute workload on any single shard (for lightweight operations)
|
||||
- **`ExecuteWithRetry(ctx, ChannelWorkload)`**: Execute on specific channel with retry on different replicas
|
||||
|
||||
**Retry Strategy**:
|
||||
- Retry up to `max(retryOnReplica, len(shardLeaders))` times
|
||||
- Maintain `excludeNodes` set to avoid retrying failed nodes
|
||||
- Refresh shard leader cache if initial attempt fails
|
||||
- Clear `excludeNodes` if all replicas exhausted
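
A simplified version of that loop is shown below. Names are illustrative; the real `ExecuteWithRetry` drives this through `retry.Handle` and delegates node selection to the balancer.

```go
import (
	"context"
	"errors"
)

// executeWithRetrySketch retries do() over the given replicas, excluding nodes
// that have already failed, and clears the exclusion set once every replica
// has been tried.
func executeWithRetrySketch(ctx context.Context, replicas []int64, retryOnReplica int,
	do func(nodeID int64) error,
) error {
	if len(replicas) == 0 {
		return errors.New("no shard leaders available")
	}
	attempts := retryOnReplica
	if len(replicas) > attempts {
		attempts = len(replicas) // max(retryOnReplica, len(shardLeaders))
	}
	exclude := make(map[int64]struct{})
	var lastErr error
	for i := 0; i < attempts; i++ {
		if err := ctx.Err(); err != nil {
			return err // respect context cancellation
		}
		if len(exclude) >= len(replicas) {
			exclude = make(map[int64]struct{}) // all replicas exhausted: start over
		}
		for _, nodeID := range replicas {
			if _, failed := exclude[nodeID]; failed {
				continue
			}
			if err := do(nodeID); err != nil {
				exclude[nodeID] = struct{}{} // do not retry this node next round
				lastErr = err
				break
			}
			return nil
		}
	}
	return lastErr
}
```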
|
||||
|
||||
**Workload Types**:
|
||||
```go
|
||||
type ChannelWorkload struct {
|
||||
Db string
|
||||
CollectionName string
|
||||
CollectionID int64
|
||||
Channel string
|
||||
Nq int64 // Number of queries
|
||||
Exec ExecuteFunc // Actual work to execute
|
||||
}
|
||||
|
||||
type ExecuteFunc func(context.Context, UniqueID, types.QueryNodeClient, string) error
|
||||
```
|
||||
|
||||
### 4. Load Balancers
|
||||
|
||||
Two strategies for selecting QueryNode replicas:
|
||||
|
||||
#### RoundRobinBalancer
|
||||
|
||||
**File**: `roundrobin_balancer.go`
|
||||
|
||||
Simple round-robin selection across available nodes. No state tracking, minimal overhead.
|
||||
|
||||
**Use case**: Uniform workload distribution when all nodes have similar capacity
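
A minimal sketch of this strategy against the `LBBalancer.SelectNode` signature (the error value here is illustrative; the real balancer returns a `merr` error):

```go
import (
	"context"
	"errors"

	"go.uber.org/atomic"
)

type roundRobinSketch struct {
	idx atomic.Int64 // shared counter, safe for concurrent callers
}

func (b *roundRobinSketch) SelectNode(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
	if len(availableNodes) == 0 {
		return -1, errors.New("no available nodes")
	}
	next := b.idx.Inc() % int64(len(availableNodes))
	return availableNodes[next], nil
}
```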
|
||||
|
||||
#### LookAsideBalancer
|
||||
|
||||
**File**: `look_aside_balancer.go`
|
||||
|
||||
Cost-aware load balancer that considers QueryNode workload and health.
|
||||
|
||||
**Features**:
|
||||
- **Cost metrics tracking**: Caches `CostAggregation` (response time, service time, total NQ) from QueryNodes
|
||||
- **Workload score calculation**: Uses power-of-3 formula to prefer lightly loaded nodes:
|
||||
```
|
||||
score = executeSpeed + (1 + totalNQ + executingNQ)³ × serviceTime
|
||||
```
|
||||
- **Periodic health checks**: Monitors QueryNode health via `GetComponentStates` RPC
|
||||
- **Unavailable node handling**: Marks nodes unreachable after consecutive health check failures
|
||||
- **Adaptive behavior**: Falls back to round-robin when workload difference is small
|
||||
|
||||
**Configuration Parameters**:
|
||||
- `ProxyCfg.CostMetricsExpireTime`: How long to trust cached cost metrics (default: varies)
|
||||
- `ProxyCfg.CheckWorkloadRequestNum`: Check workload every N requests (default: varies)
|
||||
- `ProxyCfg.WorkloadToleranceFactor`: Tolerance for workload difference before preferring lighter node
|
||||
- `ProxyCfg.CheckQueryNodeHealthInterval`: Interval for health checks
|
||||
- `ProxyCfg.HealthCheckTimeout`: Timeout for health check RPC
|
||||
- `ProxyCfg.RetryTimesOnHealthCheck`: Failures before marking node unreachable
|
||||
|
||||
**Selection Strategy**:
|
||||
```
|
||||
if (requestCount % CheckWorkloadRequestNum == 0) {
|
||||
// Cost-aware selection
|
||||
select node with minimum workload score
|
||||
if (maxScore - minScore) / minScore <= WorkloadToleranceFactor {
|
||||
fall back to round-robin
|
||||
}
|
||||
} else {
|
||||
// Fast path: round-robin
|
||||
select next available node
|
||||
}
|
||||
```
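
The sketch below puts the score formula and the tolerance check together. Field names and the fallback signalling (`-1` meaning "use round-robin") are illustrative, not the actual `LookAsideBalancer` code.

```go
import "math"

// nodeCost mirrors the cached CostAggregation values used in the score formula.
type nodeCost struct {
	executeSpeed float64
	serviceTime  float64
	totalNQ      float64
	executingNQ  float64
}

func workloadScore(c nodeCost) float64 {
	load := 1 + c.totalNQ + c.executingNQ
	return c.executeSpeed + load*load*load*c.serviceTime // power-of-3 penalty on queued work
}

// pickNode returns the node with the lowest score, or -1 when the spread of
// scores is within the tolerance factor, signalling a round-robin fallback.
func pickNode(costs map[int64]nodeCost, toleranceFactor float64) int64 {
	best, minScore, maxScore := int64(-1), math.MaxFloat64, 0.0
	for nodeID, cost := range costs {
		s := workloadScore(cost)
		if s < minScore {
			minScore, best = s, nodeID
		}
		if s > maxScore {
			maxScore = s
		}
	}
	if best >= 0 && minScore > 0 && (maxScore-minScore)/minScore <= toleranceFactor {
		return -1 // workloads are comparable: use the round-robin fast path
	}
	return best
}
```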
|
||||
|
||||
## Configuration
|
||||
|
||||
Key configuration parameters from `paramtable`:
|
||||
|
||||
| Parameter | Path | Description | Default |
|
||||
|-----------|------|-------------|---------|
|
||||
| QueryNodePoolingSize | `ProxyCfg.QueryNodePoolingSize` | Size of connection pool per QueryNode | 1 |
|
||||
| RetryTimesOnReplica | `ProxyCfg.RetryTimesOnReplica` | Max retry times on replica failures | varies |
|
||||
| ReplicaSelectionPolicy | `ProxyCfg.ReplicaSelectionPolicy` | Load balancing policy: `round_robin` or `look_aside` | `look_aside` |
|
||||
| CostMetricsExpireTime | `ProxyCfg.CostMetricsExpireTime` | Expiration time for cost metrics cache | varies |
|
||||
| CheckWorkloadRequestNum | `ProxyCfg.CheckWorkloadRequestNum` | Frequency of workload-aware selection | varies |
|
||||
| WorkloadToleranceFactor | `ProxyCfg.WorkloadToleranceFactor` | Tolerance for workload differences | varies |
|
||||
| CheckQueryNodeHealthInterval | `ProxyCfg.CheckQueryNodeHealthInterval` | Health check interval | varies |
|
||||
| HealthCheckTimeout | `ProxyCfg.HealthCheckTimeout` | Health check RPC timeout | varies |
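
All of these are read through `paramtable`, the same way the package itself does, for example:

```go
import "github.com/milvus-io/milvus/pkg/v2/util/paramtable"

func readShardClientParams() (poolSize, retryOnReplica int) {
	poolSize = paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt()       // connections per QueryNode
	retryOnReplica = paramtable.Get().ProxyCfg.RetryTimesOnReplica.GetAsInt() // replica-level retries
	return poolSize, retryOnReplica
}
```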
|
||||
|
||||
## Usage Example
|
||||
|
||||
```go
|
||||
import (
|
||||
"context"
|
||||
"github.com/milvus-io/milvus/internal/proxy/shardclient"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
)
|
||||
|
||||
// 1. Create ShardClientMgr with MixCoord client
|
||||
mgr := shardclient.NewShardClientMgr(mixCoordClient)
|
||||
mgr.Start() // Start background purge goroutine
|
||||
defer mgr.Close()
|
||||
|
||||
// 2. Create LBPolicy
|
||||
policy := shardclient.NewLBPolicyImpl(mgr)
|
||||
policy.Start(ctx) // Start load balancer (health checks, etc.)
|
||||
defer policy.Close()
|
||||
|
||||
// 3. Execute collection workload (e.g., search/query)
|
||||
workload := shardclient.CollectionWorkLoad{
|
||||
Db: "default",
|
||||
CollectionName: "my_collection",
|
||||
CollectionID: 12345,
|
||||
Nq: 100, // Number of queries
|
||||
Exec: func(ctx context.Context, nodeID int64, client types.QueryNodeClient, channel string) error {
|
||||
// Perform actual work (search, query, etc.)
|
||||
req := &querypb.SearchRequest{/* ... */}
|
||||
_, err := client.Search(ctx, req) // response handling elided in this example
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
// Execute on all channels in parallel
|
||||
err := policy.Execute(ctx, workload)
|
||||
|
||||
// Or execute on any single channel (for lightweight ops)
|
||||
err = policy.ExecuteOneChannel(ctx, workload)
|
||||
```
|
||||
|
||||
## Cache Management
|
||||
|
||||
### Shard Leader Cache
|
||||
|
||||
The shard leader cache stores the mapping of shards to their leader QueryNodes:
|
||||
|
||||
```
|
||||
database → collectionName → shardLeaders {
|
||||
collectionID: int64
|
||||
shardLeaders: map[channel][]NodeInfo
|
||||
}
|
||||
```
|
||||
|
||||
**Cache Operations**:
|
||||
- **Hit**: When cached shard leaders are used (tracked via `ProxyCacheStatsCounter`)
|
||||
- **Miss**: When cache lookup fails, triggers RPC to QueryCoord via `GetShardLeaders`
|
||||
- **Invalidation**:
|
||||
- `DeprecateShardCache(db, collection)`: Remove specific collection
|
||||
- `InvalidateShardLeaderCache(collectionIDs)`: Remove collections by ID (called on shard leader changes)
|
||||
- `RemoveDatabase(db)`: Remove entire database
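
In Go terms the cached structure looks roughly like the sketch below (illustrative names, not the actual structs):

```go
// Illustrative only: nested shard-leader cache keyed by database and collection.
type cachedShardLeaders struct {
	collectionID int64
	leaders      map[string][]NodeInfo // channel -> replica nodes
}

type leaderCache map[string]map[string]*cachedShardLeaders // database -> collection

// Deprecating a collection simply drops its entry; the next lookup misses and
// refreshes from QueryCoord.
func deprecate(cache leaderCache, db, collection string) {
	if collections, ok := cache[db]; ok {
		delete(collections, collection)
	}
}
```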
|
||||
|
||||
### Client Purging
|
||||
|
||||
The `ShardClientMgr` periodically purges unused clients:
|
||||
|
||||
1. Every `purgeInterval` (default: 600s), iterate all cached clients
|
||||
2. Check if client is still a shard leader (via `ListShardLocation()`)
|
||||
3. If it is no longer a leader and has been inactive longer than `expiredDuration` (measured from `lastActiveTs`), close and remove it
|
||||
4. This prevents connection leaks when QueryNodes are removed or shards rebalance
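
A condensed sketch of that loop, reusing the `pooledClient` sketch from the shardClient section above (the real manager uses a concurrent map and also closes the pooled connections before dropping them):

```go
import "time"

func purgeLoopSketch(closeCh <-chan struct{}, purgeInterval, expiredDuration time.Duration,
	clients map[int64]*pooledClient, isShardLeader func(nodeID int64) bool,
) {
	ticker := time.NewTicker(purgeInterval)
	defer ticker.Stop()
	for {
		select {
		case <-closeCh:
			return
		case <-ticker.C:
			now := time.Now().UnixNano()
			for nodeID, client := range clients {
				if isShardLeader(nodeID) {
					continue // still a shard leader: keep the pool warm
				}
				if now-client.lastActiveTs.Load() > expiredDuration.Nanoseconds() {
					delete(clients, nodeID) // expired and unreferenced: drop it
				}
			}
		}
	}
}
```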
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Common Errors
|
||||
|
||||
- **`errClosed`**: Client is closed (returned when accessing closed `shardClient`)
|
||||
- **`merr.ErrChannelNotAvailable`**: No available shard leaders for channel
|
||||
- **`merr.ErrNodeNotAvailable`**: Selected node is not available
|
||||
- **`merr.ErrCollectionNotLoaded`**: Collection is not loaded in QueryNodes
|
||||
- **`merr.ErrServiceUnavailable`**: All available nodes are unreachable
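
Callers usually branch on these sentinels with `errors.Is`, as the package itself does for `merr.ErrCollectionNotLoaded`:

```go
import (
	"github.com/cockroachdb/errors"

	"github.com/milvus-io/milvus/pkg/v2/util/merr"
)

func shouldRetry(err error) bool {
	// a collection that is not loaded will not recover by retrying the same request
	return err != nil && !errors.Is(err, merr.ErrCollectionNotLoaded)
}
```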
|
||||
|
||||
### Retry Logic
|
||||
|
||||
Retry is handled at multiple levels:
|
||||
|
||||
1. **LBPolicy level**:
|
||||
- Retries on different replicas when request fails
|
||||
- Refreshes shard leader cache on failure
|
||||
- Respects context cancellation
|
||||
|
||||
2. **Balancer level**:
|
||||
- Tracks failed nodes and excludes them from selection
|
||||
- Health checks recover nodes when they come back online
|
||||
|
||||
3. **gRPC level**:
|
||||
- Connection-level retries handled by gRPC layer
|
||||
|
||||
## Metrics
|
||||
|
||||
The package exports several metrics:
|
||||
|
||||
- `ProxyCacheStatsCounter`: Shard leader cache hit/miss statistics
|
||||
- Labels: `nodeID`, `method` (GetShard/GetShardLeaderList), `status` (hit/miss)
|
||||
- `ProxyUpdateCacheLatency`: Latency of updating shard leader cache
|
||||
- Labels: `nodeID`, `method`
|
||||
|
||||
## Testing
|
||||
|
||||
The package includes extensive test coverage:
|
||||
|
||||
- `shard_client_test.go`: Tests for connection pool management
|
||||
- `manager_test.go`: Tests for cache management and client lifecycle
|
||||
- `lb_policy_test.go`: Tests for retry logic and workload execution
|
||||
- `roundrobin_balancer_test.go`: Tests for round-robin selection
|
||||
- `look_aside_balancer_test.go`: Tests for cost-aware selection and health checks
|
||||
|
||||
**Mock interfaces** (via mockery):
|
||||
- `mock_shardclient_manager.go`: Mock `ShardClientMgr`
|
||||
- `mock_lb_policy.go`: Mock `LBPolicy`
|
||||
- `mock_lb_balancer.go`: Mock `LBBalancer`
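
For example, a proxy test can construct the manager mock in the usual mockery expecter style (mirroring how the proxy tests elsewhere in this change set it up):

```go
import (
	"testing"

	"github.com/stretchr/testify/mock"

	"github.com/milvus-io/milvus/internal/proxy/shardclient"
)

func TestWithMockShardClientMgr(t *testing.T) {
	mgr := shardclient.NewMockShardClientManager(t)
	// Maybe() keeps the expectations optional so code paths that never touch
	// the manager do not fail the test.
	mgr.EXPECT().SetClientCreatorFunc(mock.Anything).Return().Maybe()
	mgr.EXPECT().Close().Return().Maybe()

	_ = mgr // inject wherever a shardclient.ShardClientMgr is expected
}
```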
|
||||
|
||||
## Thread Safety
|
||||
|
||||
All components are designed for concurrent access:
|
||||
|
||||
- `shardClientMgrImpl`: Uses `sync.RWMutex` for cache, `typeutil.ConcurrentMap` for clients
|
||||
- `shardClient`: Uses `sync.RWMutex` and atomic operations
|
||||
- `LookAsideBalancer`: Uses `typeutil.ConcurrentMap` for all mutable state
|
||||
- `RoundRobinBalancer`: Uses `atomic.Int64` for index
|
||||
|
||||
## Related Components
|
||||
|
||||
- **Proxy** (`internal/proxy/`): Uses `shardclient` to route search/query requests to QueryNodes
|
||||
- **QueryCoord** (`internal/querycoordv2/`): Provides shard leader information via `GetShardLeaders` RPC
|
||||
- **QueryNode** (`internal/querynodev2/`): Receives and processes requests routed by `shardclient`
|
||||
- **Registry** (`internal/registry/`): Provides client creation functions for gRPC connections
|
||||
|
||||
## Future Improvements
|
||||
|
||||
Potential areas for enhancement:
|
||||
|
||||
1. **Adaptive pooling**: Dynamically adjust connection pool size based on load
|
||||
2. **Circuit breaker**: Add circuit breaker pattern for consistently failing nodes
|
||||
3. **Advanced metrics**: Export more detailed metrics (per-node latency, error rates, etc.)
|
||||
4. **Smart caching**: Use TTL-based cache expiration instead of invalidation-only
|
||||
5. **Connection warming**: Pre-establish connections to known QueryNodes
|
||||
@ -14,7 +14,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package proxy
|
||||
package shardclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -23,7 +23,7 @@ import (
|
||||
)
|
||||
|
||||
type LBBalancer interface {
|
||||
RegisterNodeInfo(nodeInfos []nodeInfo)
|
||||
RegisterNodeInfo(nodeInfos []NodeInfo)
|
||||
SelectNode(ctx context.Context, availableNodes []int64, nq int64) (int64, error)
|
||||
CancelWorkload(node int64, nq int64)
|
||||
UpdateCostMetrics(node int64, cost *internalpb.CostAggregation)
|
||||
@ -13,7 +13,8 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package proxy
|
||||
|
||||
package shardclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -30,27 +31,28 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/retry"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
type executeFunc func(context.Context, UniqueID, types.QueryNodeClient, string) error
|
||||
type ExecuteFunc func(context.Context, UniqueID, types.QueryNodeClient, string) error
|
||||
|
||||
type ChannelWorkload struct {
|
||||
db string
|
||||
collectionName string
|
||||
collectionID int64
|
||||
channel string
|
||||
nq int64
|
||||
exec executeFunc
|
||||
Db string
|
||||
CollectionName string
|
||||
CollectionID int64
|
||||
Channel string
|
||||
Nq int64
|
||||
Exec ExecuteFunc
|
||||
}
|
||||
|
||||
type CollectionWorkLoad struct {
|
||||
db string
|
||||
collectionName string
|
||||
collectionID int64
|
||||
nq int64
|
||||
exec executeFunc
|
||||
Db string
|
||||
CollectionName string
|
||||
CollectionID int64
|
||||
Nq int64
|
||||
Exec ExecuteFunc
|
||||
}
|
||||
|
||||
type LBPolicy interface {
|
||||
@ -69,12 +71,12 @@ const (
|
||||
|
||||
type LBPolicyImpl struct {
|
||||
getBalancer func() LBBalancer
|
||||
clientMgr shardClientMgr
|
||||
clientMgr ShardClientMgr
|
||||
balancerMap map[string]LBBalancer
|
||||
retryOnReplica int
|
||||
}
|
||||
|
||||
func NewLBPolicyImpl(clientMgr shardClientMgr) *LBPolicyImpl {
|
||||
func NewLBPolicyImpl(clientMgr ShardClientMgr) *LBPolicyImpl {
|
||||
balancerMap := make(map[string]LBBalancer)
|
||||
balancerMap[LookAside] = NewLookAsideBalancer(clientMgr)
|
||||
balancerMap[RoundRobin] = NewRoundRobinBalancer()
|
||||
@ -87,7 +89,7 @@ func NewLBPolicyImpl(clientMgr shardClientMgr) *LBPolicyImpl {
|
||||
return balancerMap[balancePolicy]
|
||||
}
|
||||
|
||||
retryOnReplica := Params.ProxyCfg.RetryTimesOnReplica.GetAsInt()
|
||||
retryOnReplica := paramtable.Get().ProxyCfg.RetryTimesOnReplica.GetAsInt()
|
||||
|
||||
return &LBPolicyImpl{
|
||||
getBalancer: getBalancer,
|
||||
@ -105,11 +107,11 @@ func (lb *LBPolicyImpl) Start(ctx context.Context) {
|
||||
|
||||
// GetShard will retry until ctx done, except the collection is not loaded.
|
||||
// return all replicas of shard from cache if withCache is true, otherwise return shard leaders from coord.
|
||||
func (lb *LBPolicyImpl) GetShard(ctx context.Context, dbName string, collName string, collectionID int64, channel string, withCache bool) ([]nodeInfo, error) {
|
||||
var shardLeaders []nodeInfo
|
||||
func (lb *LBPolicyImpl) GetShard(ctx context.Context, dbName string, collName string, collectionID int64, channel string, withCache bool) ([]NodeInfo, error) {
|
||||
var shardLeaders []NodeInfo
|
||||
err := retry.Handle(ctx, func() (bool, error) {
|
||||
var err error
|
||||
shardLeaders, err = globalMetaCache.GetShard(ctx, withCache, dbName, collName, collectionID, channel)
|
||||
shardLeaders, err = lb.clientMgr.GetShard(ctx, withCache, dbName, collName, collectionID, channel)
|
||||
return !errors.Is(err, merr.ErrCollectionNotLoaded), err
|
||||
})
|
||||
return shardLeaders, err
|
||||
@ -121,25 +123,25 @@ func (lb *LBPolicyImpl) GetShardLeaderList(ctx context.Context, dbName string, c
|
||||
var ret []string
|
||||
err := retry.Handle(ctx, func() (bool, error) {
|
||||
var err error
|
||||
ret, err = globalMetaCache.GetShardLeaderList(ctx, dbName, collName, collectionID, withCache)
|
||||
ret, err = lb.clientMgr.GetShardLeaderList(ctx, dbName, collName, collectionID, withCache)
|
||||
return !errors.Is(err, merr.ErrCollectionNotLoaded), err
|
||||
})
|
||||
return ret, err
|
||||
}
|
||||
|
||||
// try to select the best node from the available nodes
|
||||
func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes *typeutil.UniqueSet) (nodeInfo, error) {
|
||||
func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes *typeutil.UniqueSet) (NodeInfo, error) {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("collectionID", workload.collectionID),
|
||||
zap.String("channelName", workload.channel),
|
||||
zap.Int64("collectionID", workload.CollectionID),
|
||||
zap.String("channelName", workload.Channel),
|
||||
)
|
||||
// Select node using specified nodes
|
||||
trySelectNode := func(withCache bool) (nodeInfo, error) {
|
||||
shardLeaders, err := lb.GetShard(ctx, workload.db, workload.collectionName, workload.collectionID, workload.channel, withCache)
|
||||
trySelectNode := func(withCache bool) (NodeInfo, error) {
|
||||
shardLeaders, err := lb.GetShard(ctx, workload.Db, workload.CollectionName, workload.CollectionID, workload.Channel, withCache)
|
||||
if err != nil {
|
||||
log.Warn("failed to get shard delegator",
|
||||
zap.Error(err))
|
||||
return nodeInfo{}, err
|
||||
return NodeInfo{}, err
|
||||
}
|
||||
|
||||
// if all available delegator has been excluded even after refresh shard leader cache
|
||||
@ -147,7 +149,7 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
|
||||
if !withCache && len(shardLeaders) > 0 && len(shardLeaders) <= excludeNodes.Len() {
|
||||
allReplicaExcluded := true
|
||||
for _, node := range shardLeaders {
|
||||
if !excludeNodes.Contain(node.nodeID) {
|
||||
if !excludeNodes.Contain(node.NodeID) {
|
||||
allReplicaExcluded = false
|
||||
break
|
||||
}
|
||||
@ -158,14 +160,14 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
|
||||
}
|
||||
}
|
||||
|
||||
candidateNodes := make(map[int64]nodeInfo)
|
||||
serviceableNodes := make(map[int64]nodeInfo)
|
||||
candidateNodes := make(map[int64]NodeInfo)
|
||||
serviceableNodes := make(map[int64]NodeInfo)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
candidatesInStr := lo.Map(shardLeaders, func(node nodeInfo, _ int) string {
|
||||
candidatesInStr := lo.Map(shardLeaders, func(node NodeInfo, _ int) string {
|
||||
return node.String()
|
||||
})
|
||||
serviceableNodesInStr := lo.Map(lo.Values(serviceableNodes), func(node nodeInfo, _ int) string {
|
||||
serviceableNodesInStr := lo.Map(lo.Values(serviceableNodes), func(node NodeInfo, _ int) string {
|
||||
return node.String()
|
||||
})
|
||||
log.Warn("failed to select shard",
|
||||
@ -178,33 +180,33 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
|
||||
|
||||
// Filter nodes based on excludeNodes
|
||||
for _, node := range shardLeaders {
|
||||
if !excludeNodes.Contain(node.nodeID) {
|
||||
if node.serviceable {
|
||||
serviceableNodes[node.nodeID] = node
|
||||
if !excludeNodes.Contain(node.NodeID) {
|
||||
if node.Serviceable {
|
||||
serviceableNodes[node.NodeID] = node
|
||||
}
|
||||
candidateNodes[node.nodeID] = node
|
||||
candidateNodes[node.NodeID] = node
|
||||
}
|
||||
}
|
||||
if len(candidateNodes) == 0 {
|
||||
err = merr.WrapErrChannelNotAvailable(workload.channel, "no available shard leaders")
|
||||
return nodeInfo{}, err
|
||||
err = merr.WrapErrChannelNotAvailable(workload.Channel, "no available shard leaders")
|
||||
return NodeInfo{}, err
|
||||
}
|
||||
|
||||
balancer.RegisterNodeInfo(lo.Values(candidateNodes))
|
||||
// prefer serviceable nodes
|
||||
var targetNodeID int64
|
||||
if len(serviceableNodes) > 0 {
|
||||
targetNodeID, err = balancer.SelectNode(ctx, lo.Keys(serviceableNodes), workload.nq)
|
||||
targetNodeID, err = balancer.SelectNode(ctx, lo.Keys(serviceableNodes), workload.Nq)
|
||||
} else {
|
||||
targetNodeID, err = balancer.SelectNode(ctx, lo.Keys(candidateNodes), workload.nq)
|
||||
targetNodeID, err = balancer.SelectNode(ctx, lo.Keys(candidateNodes), workload.Nq)
|
||||
}
|
||||
if err != nil {
|
||||
return nodeInfo{}, err
|
||||
return NodeInfo{}, err
|
||||
}
|
||||
|
||||
if _, ok := candidateNodes[targetNodeID]; !ok {
|
||||
err = merr.WrapErrNodeNotAvailable(targetNodeID)
|
||||
return nodeInfo{}, err
|
||||
return NodeInfo{}, err
|
||||
}
|
||||
|
||||
return candidateNodes[targetNodeID], nil
|
||||
@ -218,7 +220,7 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
|
||||
withShardLeaderCache = false
|
||||
targetNode, err = trySelectNode(withShardLeaderCache)
|
||||
if err != nil {
|
||||
return nodeInfo{}, err
|
||||
return NodeInfo{}, err
|
||||
}
|
||||
}
|
||||
|
||||
@ -228,8 +230,8 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
|
||||
// ExecuteWithRetry will choose a qn to execute the workload, and retry if failed, until reach the max retryTimes.
|
||||
func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("collectionID", workload.collectionID),
|
||||
zap.String("channelName", workload.channel),
|
||||
zap.Int64("collectionID", workload.CollectionID),
|
||||
zap.String("channelName", workload.Channel),
|
||||
)
|
||||
var lastErr error
|
||||
excludeNodes := typeutil.NewUniqueSet()
|
||||
@ -238,7 +240,7 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
|
||||
targetNode, err := lb.selectNode(ctx, balancer, workload, &excludeNodes)
|
||||
if err != nil {
|
||||
log.Warn("failed to select node for shard",
|
||||
zap.Int64("nodeID", targetNode.nodeID),
|
||||
zap.Int64("nodeID", targetNode.NodeID),
|
||||
zap.Int64s("excluded", excludeNodes.Collect()),
|
||||
zap.Error(err),
|
||||
)
|
||||
@ -248,33 +250,33 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
|
||||
return true, err
|
||||
}
|
||||
// cancel work load which assign to the target node
|
||||
defer balancer.CancelWorkload(targetNode.nodeID, workload.nq)
|
||||
defer balancer.CancelWorkload(targetNode.NodeID, workload.Nq)
|
||||
|
||||
client, err := lb.clientMgr.GetClient(ctx, targetNode)
|
||||
if err != nil {
|
||||
log.Warn("search/query channel failed, node not available",
|
||||
zap.Int64("nodeID", targetNode.nodeID),
|
||||
zap.Int64("nodeID", targetNode.NodeID),
|
||||
zap.Error(err))
|
||||
excludeNodes.Insert(targetNode.nodeID)
|
||||
excludeNodes.Insert(targetNode.NodeID)
|
||||
|
||||
lastErr = errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode.nodeID, workload.channel)
|
||||
lastErr = errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode.NodeID, workload.Channel)
|
||||
return true, lastErr
|
||||
}
|
||||
|
||||
err = workload.exec(ctx, targetNode.nodeID, client, workload.channel)
|
||||
err = workload.Exec(ctx, targetNode.NodeID, client, workload.Channel)
|
||||
if err != nil {
|
||||
log.Warn("search/query channel failed",
|
||||
zap.Int64("nodeID", targetNode.nodeID),
|
||||
zap.Int64("nodeID", targetNode.NodeID),
|
||||
zap.Error(err))
|
||||
excludeNodes.Insert(targetNode.nodeID)
|
||||
lastErr = errors.Wrapf(err, "failed to search/query delegator %d for channel %s", targetNode.nodeID, workload.channel)
|
||||
excludeNodes.Insert(targetNode.NodeID)
|
||||
lastErr = errors.Wrapf(err, "failed to search/query delegator %d for channel %s", targetNode.NodeID, workload.Channel)
|
||||
return true, lastErr
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
shardLeaders, err := lb.GetShard(ctx, workload.db, workload.collectionName, workload.collectionID, workload.channel, true)
|
||||
shardLeaders, err := lb.GetShard(ctx, workload.Db, workload.CollectionName, workload.CollectionID, workload.Channel, true)
|
||||
if err != nil {
|
||||
log.Warn("failed to get shard leaders", zap.Error(err))
|
||||
return err
|
||||
@ -283,7 +285,7 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
|
||||
err = retry.Handle(ctx, tryExecute, retry.Attempts(uint(retryTimes)))
|
||||
if err != nil {
|
||||
log.Warn("failed to execute",
|
||||
zap.String("channel", workload.channel),
|
||||
zap.String("channel", workload.Channel),
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
@ -293,17 +295,17 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
|
||||
// Execute will execute collection workload in parallel
|
||||
func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad) error {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("collectionID", workload.collectionID),
|
||||
zap.Int64("collectionID", workload.CollectionID),
|
||||
)
|
||||
channelList, err := lb.GetShardLeaderList(ctx, workload.db, workload.collectionName, workload.collectionID, true)
|
||||
channelList, err := lb.GetShardLeaderList(ctx, workload.Db, workload.CollectionName, workload.CollectionID, true)
|
||||
if err != nil {
|
||||
log.Warn("failed to get shards", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
if len(channelList) == 0 {
|
||||
log.Info("no shard leaders found", zap.Int64("collectionID", workload.collectionID))
|
||||
return merr.WrapErrCollectionNotLoaded(workload.collectionID)
|
||||
log.Info("no shard leaders found", zap.Int64("collectionID", workload.CollectionID))
|
||||
return merr.WrapErrCollectionNotLoaded(workload.CollectionID)
|
||||
}
|
||||
|
||||
wg, _ := errgroup.WithContext(ctx)
|
||||
@ -311,12 +313,12 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
|
||||
for _, channel := range channelList {
|
||||
wg.Go(func() error {
|
||||
return lb.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
db: workload.db,
|
||||
collectionName: workload.collectionName,
|
||||
collectionID: workload.collectionID,
|
||||
channel: channel,
|
||||
nq: workload.nq,
|
||||
exec: workload.exec,
|
||||
Db: workload.Db,
|
||||
CollectionName: workload.CollectionName,
|
||||
CollectionID: workload.CollectionID,
|
||||
Channel: channel,
|
||||
Nq: workload.Nq,
|
||||
Exec: workload.Exec,
|
||||
})
|
||||
})
|
||||
}
|
||||
@ -325,7 +327,7 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
|
||||
|
||||
// Execute will execute any one channel in collection workload
|
||||
func (lb *LBPolicyImpl) ExecuteOneChannel(ctx context.Context, workload CollectionWorkLoad) error {
|
||||
channelList, err := lb.GetShardLeaderList(ctx, workload.db, workload.collectionName, workload.collectionID, true)
|
||||
channelList, err := lb.GetShardLeaderList(ctx, workload.Db, workload.CollectionName, workload.CollectionID, true)
|
||||
if err != nil {
|
||||
log.Ctx(ctx).Warn("failed to get shards", zap.Error(err))
|
||||
return err
|
||||
@ -334,15 +336,15 @@ func (lb *LBPolicyImpl) ExecuteOneChannel(ctx context.Context, workload Collecti
|
||||
// let every request could retry at least twice, which could retry after update shard leader cache
|
||||
for _, channel := range channelList {
|
||||
return lb.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
db: workload.db,
|
||||
collectionName: workload.collectionName,
|
||||
collectionID: workload.collectionID,
|
||||
channel: channel,
|
||||
nq: workload.nq,
|
||||
exec: workload.exec,
|
||||
Db: workload.Db,
|
||||
CollectionName: workload.CollectionName,
|
||||
CollectionID: workload.CollectionID,
|
||||
Channel: channel,
|
||||
Nq: workload.Nq,
|
||||
Exec: workload.Exec,
|
||||
})
|
||||
}
|
||||
return fmt.Errorf("no acitvate sheard leader exist for collection: %s", workload.collectionName)
|
||||
return fmt.Errorf("no acitvate sheard leader exist for collection: %s", workload.CollectionName)
|
||||
}
|
||||
|
||||
func (lb *LBPolicyImpl) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {
|
||||
664
internal/proxy/shardclient/lb_policy_test.go
Normal file
@ -0,0 +1,664 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package shardclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
type LBPolicySuite struct {
|
||||
suite.Suite
|
||||
qn *mocks.MockQueryNodeClient
|
||||
|
||||
mgr *MockShardClientManager
|
||||
lbBalancer *MockLBBalancer
|
||||
lbPolicy *LBPolicyImpl
|
||||
|
||||
nodeIDs []int64
|
||||
nodes []NodeInfo
|
||||
channels []string
|
||||
|
||||
dbName string
|
||||
collectionName string
|
||||
collectionID int64
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) SetupSuite() {
|
||||
paramtable.Init()
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) SetupTest() {
|
||||
s.nodeIDs = make([]int64, 0)
|
||||
s.nodes = make([]NodeInfo, 0)
|
||||
for i := 1; i <= 5; i++ {
|
||||
s.nodeIDs = append(s.nodeIDs, int64(i))
|
||||
s.nodes = append(s.nodes, NodeInfo{
|
||||
NodeID: int64(i),
|
||||
Address: "localhost",
|
||||
Serviceable: true,
|
||||
})
|
||||
}
|
||||
s.channels = []string{"channel1", "channel2"}
|
||||
|
||||
s.qn = mocks.NewMockQueryNodeClient(s.T())
|
||||
s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
|
||||
s.mgr = NewMockShardClientManager(s.T())
|
||||
s.lbBalancer = NewMockLBBalancer(s.T())
|
||||
s.lbBalancer.EXPECT().Start(mock.Anything).Maybe()
|
||||
s.lbBalancer.EXPECT().Close().Maybe()
|
||||
|
||||
s.lbPolicy = NewLBPolicyImpl(s.mgr)
|
||||
s.lbPolicy.Start(context.Background())
|
||||
s.lbPolicy.getBalancer = func() LBBalancer {
|
||||
return s.lbBalancer
|
||||
}
|
||||
|
||||
s.dbName = "test_lb_policy"
|
||||
s.collectionName = "test_lb_policy"
|
||||
s.collectionID = 100
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TearDownTest() {
|
||||
s.lbPolicy.Close()
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestSelectNode() {
|
||||
ctx := context.Background()
|
||||
|
||||
// test select node success
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil)
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(5), nil)
|
||||
excludeNodes := typeutil.NewUniqueSet()
|
||||
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Channel: s.channels[0],
|
||||
Nq: 1,
|
||||
}, &excludeNodes)
|
||||
s.NoError(err)
|
||||
s.Equal(int64(5), targetNode.NodeID)
|
||||
|
||||
// test select node failed, then update shard leader cache and retry, expect success
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
// First call with cache fails
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Once()
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(-1), errors.New("fake err")).Once()
|
||||
// Second call without cache succeeds
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Once()
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(3), nil).Once()
|
||||
excludeNodes = typeutil.NewUniqueSet()
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Channel: s.channels[0],
|
||||
Nq: 1,
|
||||
}, &excludeNodes)
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), targetNode.NodeID)
|
||||
|
||||
// test select node always fails, expected failure
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Once()
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Once()
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(-1), merr.ErrNodeNotAvailable)
|
||||
excludeNodes = typeutil.NewUniqueSet()
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Channel: s.channels[0],
|
||||
Nq: 1,
|
||||
}, &excludeNodes)
|
||||
s.ErrorIs(err, merr.ErrNodeNotAvailable)
|
||||
|
||||
// test all nodes have been excluded, expect excludeNodes to be cleared and node selection retried
|
||||
excludeNodes = typeutil.NewUniqueSet(s.nodeIDs...)
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Once()
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Once()
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(-1), merr.ErrNodeNotAvailable)
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Channel: s.channels[0],
|
||||
Nq: 1,
|
||||
}, &excludeNodes)
|
||||
s.ErrorIs(err, merr.ErrNodeNotAvailable)
|
||||
|
||||
// test get shard leaders failed, so the node-selection retry also fails
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(nil, merr.ErrCollectionNotLoaded).Once()
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(nil, merr.ErrCollectionNotLoaded).Once()
|
||||
excludeNodes = typeutil.NewUniqueSet()
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Channel: s.channels[0],
|
||||
Nq: 1,
|
||||
}, &excludeNodes)
|
||||
s.ErrorIs(err, merr.ErrCollectionNotLoaded)
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestExecuteWithRetry() {
|
||||
ctx := context.Background()
|
||||
|
||||
// test execute success
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil)
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil)
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
err := s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Channel: s.channels[0],
|
||||
Nq: 1,
|
||||
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
// test select node failed, expected error
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil)
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Maybe()
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(-1), merr.ErrNodeNotAvailable)
|
||||
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Channel: s.channels[0],
|
||||
Nq: 1,
|
||||
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
s.ErrorIs(err, merr.ErrNodeNotAvailable)
|
||||
|
||||
// test get client failed, and retry failed, expected error
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil)
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Maybe()
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error"))
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
|
||||
return availableNodes[0], nil
|
||||
})
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Channel: s.channels[0],
|
||||
Nq: 1,
|
||||
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
s.Error(err)
|
||||
|
||||
// test get client failed once, then retry success
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil)
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Maybe()
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Once()
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
|
||||
return availableNodes[0], nil
|
||||
})
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Channel: s.channels[0],
|
||||
Nq: 1,
|
||||
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
// test exec failed, then retry success
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil)
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Maybe()
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
|
||||
return availableNodes[0], nil
|
||||
})
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
counter := 0
|
||||
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Channel: s.channels[0],
|
||||
Nq: 1,
|
||||
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
counter++
|
||||
if counter == 1 {
|
||||
return errors.New("fake error")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
// test exec timeout
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil)
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Maybe()
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
|
||||
return availableNodes[0], nil
|
||||
})
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
||||
s.qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, context.Canceled).Once()
|
||||
s.qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, context.DeadlineExceeded)
|
||||
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Channel: s.channels[0],
|
||||
Nq: 1,
|
||||
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
_, err := qn.Search(ctx, nil)
|
||||
return err
|
||||
},
|
||||
})
|
||||
s.True(merr.IsCanceledOrTimeout(err))
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestExecuteOneChannel() {
|
||||
ctx := context.Background()
|
||||
mockErr := errors.New("mock error")
|
||||
|
||||
// test all channel success
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShardLeaderList(mock.Anything, s.dbName, s.collectionName, s.collectionID, true).Return(s.channels, nil)
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil)
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil)
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
err := s.lbPolicy.ExecuteOneChannel(ctx, CollectionWorkLoad{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Nq: 1,
|
||||
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
// test get shard leader failed
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShardLeaderList(mock.Anything, s.dbName, s.collectionName, s.collectionID, true).Return(nil, mockErr)
|
||||
err = s.lbPolicy.ExecuteOneChannel(ctx, CollectionWorkLoad{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Nq: 1,
|
||||
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
s.ErrorIs(err, mockErr)
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestExecute() {
|
||||
ctx := context.Background()
|
||||
mockErr := errors.New("mock error")
|
||||
|
||||
// test all channel success
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShardLeaderList(mock.Anything, s.dbName, s.collectionName, s.collectionID, true).Return(s.channels, nil)
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, mock.Anything).Return(s.nodes, nil)
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, mock.Anything).Return(s.nodes, nil).Maybe()
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
|
||||
return availableNodes[0], nil
|
||||
})
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Nq: 1,
|
||||
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
// test some channel failed
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShardLeaderList(mock.Anything, s.dbName, s.collectionName, s.collectionID, true).Return(s.channels, nil)
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, mock.Anything).Return(s.nodes, nil)
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, mock.Anything).Return(s.nodes, nil).Maybe()
|
||||
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
|
||||
return availableNodes[0], nil
|
||||
})
|
||||
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
|
||||
counter := atomic.NewInt64(0)
|
||||
err = s.lbPolicy.Execute(ctx, CollectionWorkLoad{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Nq: 1,
|
||||
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
// succeed in first execute
|
||||
if counter.Add(1) == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return mockErr
|
||||
},
|
||||
})
|
||||
s.Error(err)
|
||||
s.Equal(int64(6), counter.Load())
|
||||
|
||||
// test get shard leader failed
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShardLeaderList(mock.Anything, s.dbName, s.collectionName, s.collectionID, true).Return(nil, mockErr)
|
||||
err = s.lbPolicy.Execute(ctx, CollectionWorkLoad{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Nq: 1,
|
||||
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
s.ErrorIs(err, mockErr)
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestUpdateCostMetrics() {
|
||||
s.lbBalancer.EXPECT().UpdateCostMetrics(mock.Anything, mock.Anything)
|
||||
s.lbPolicy.UpdateCostMetrics(1, &internalpb.CostAggregation{})
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestNewLBPolicy() {
|
||||
mgr := NewMockShardClientManager(s.T())
|
||||
policy := NewLBPolicyImpl(mgr)
|
||||
s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*shardclient.LookAsideBalancer")
|
||||
policy.Close()
|
||||
|
||||
params := paramtable.Get()
|
||||
|
||||
params.Save(params.ProxyCfg.ReplicaSelectionPolicy.Key, "round_robin")
|
||||
policy = NewLBPolicyImpl(mgr)
|
||||
s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*shardclient.RoundRobinBalancer")
|
||||
policy.Close()
|
||||
|
||||
params.Save(params.ProxyCfg.ReplicaSelectionPolicy.Key, "look_aside")
|
||||
policy = NewLBPolicyImpl(mgr)
|
||||
s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*shardclient.LookAsideBalancer")
|
||||
policy.Close()
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestGetShard() {
|
||||
ctx := context.Background()
|
||||
|
||||
// ErrCollectionNotLoaded is not retriable, expected to fail fast
|
||||
counter := atomic.NewInt64(0)
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).RunAndReturn(
|
||||
func(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]NodeInfo, error) {
|
||||
counter.Inc()
|
||||
return nil, merr.ErrCollectionNotLoaded
|
||||
})
|
||||
_, err := s.lbPolicy.GetShard(ctx, s.dbName, s.collectionName, s.collectionID, s.channels[0], true)
|
||||
s.Error(err)
|
||||
s.Equal(int64(1), counter.Load())
|
||||
|
||||
// Normal case - success
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil)
|
||||
shardLeaders, err := s.lbPolicy.GetShard(ctx, s.dbName, s.collectionName, s.collectionID, s.channels[0], true)
|
||||
s.NoError(err)
|
||||
s.Equal(len(s.nodes), len(shardLeaders))
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestSelectNodeEdgeCases() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test case 1: Empty shard leaders after refresh, should fail gracefully
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return([]NodeInfo{}, nil).Once()
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return([]NodeInfo{}, nil).Once()
|
||||
|
||||
excludeNodes := typeutil.NewUniqueSet(s.nodeIDs...)
|
||||
_, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Channel: s.channels[0],
|
||||
Nq: 1,
|
||||
}, &excludeNodes)
|
||||
s.Error(err)
|
||||
|
||||
// Test case 2: Single replica scenario - exclude it, refresh shows same single replica, should clear and succeed
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
singleNode := []NodeInfo{{NodeID: 1, Address: "localhost:9000", Serviceable: true}}
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(singleNode, nil).Once()
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(singleNode, nil).Once()
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil).Once()
|
||||
|
||||
excludeNodes = typeutil.NewUniqueSet(int64(1)) // Exclude the single node
|
||||
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Channel: s.channels[0],
|
||||
Nq: 1,
|
||||
}, &excludeNodes)
|
||||
s.NoError(err)
|
||||
s.Equal(int64(1), targetNode.NodeID)
|
||||
s.Equal(0, excludeNodes.Len()) // Should be cleared
|
||||
|
||||
// Test case 3: Mixed serviceable nodes - prefer serviceable over non-serviceable
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
mixedNodes := []NodeInfo{
|
||||
{NodeID: 1, Address: "localhost:9000", Serviceable: true},
|
||||
{NodeID: 2, Address: "localhost:9001", Serviceable: false},
|
||||
{NodeID: 3, Address: "localhost:9002", Serviceable: true},
|
||||
}
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(mixedNodes, nil).Once()
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
// Should select from serviceable nodes only (node 3, since node 1 is excluded)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.MatchedBy(func(nodes []int64) bool {
|
||||
return len(nodes) == 1 && nodes[0] == 3 // Only node 3 is serviceable and not excluded
|
||||
}), mock.Anything).Return(int64(3), nil).Once()
|
||||
|
||||
excludeNodes = typeutil.NewUniqueSet(int64(1)) // Exclude node 1, node 3 should be available
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Channel: s.channels[0],
|
||||
Nq: 1,
|
||||
}, &excludeNodes)
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), targetNode.NodeID)
|
||||
s.Equal(1, excludeNodes.Len()) // Should NOT be cleared as not all replicas were excluded
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestGetShardLeaderList() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test normal scenario with cache
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShardLeaderList(mock.Anything, s.dbName, s.collectionName, s.collectionID, true).Return(s.channels, nil)
|
||||
channelList, err := s.lbPolicy.GetShardLeaderList(ctx, s.dbName, s.collectionName, s.collectionID, true)
|
||||
s.NoError(err)
|
||||
s.Equal(len(s.channels), len(channelList))
|
||||
s.Contains(channelList, s.channels[0])
|
||||
s.Contains(channelList, s.channels[1])
|
||||
|
||||
// Test without cache - should refresh from coordinator
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShardLeaderList(mock.Anything, s.dbName, s.collectionName, s.collectionID, false).Return(s.channels, nil)
|
||||
channelList, err = s.lbPolicy.GetShardLeaderList(ctx, s.dbName, s.collectionName, s.collectionID, false)
|
||||
s.NoError(err)
|
||||
s.Equal(len(s.channels), len(channelList))
|
||||
|
||||
// Test error case - collection not loaded
|
||||
counter := atomic.NewInt64(0)
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShardLeaderList(mock.Anything, s.dbName, s.collectionName, s.collectionID, true).RunAndReturn(
|
||||
func(ctx context.Context, database, collectionName string, collectionID int64, withCache bool) ([]string, error) {
|
||||
counter.Inc()
|
||||
return nil, merr.ErrCollectionNotLoaded
|
||||
})
|
||||
_, err = s.lbPolicy.GetShardLeaderList(ctx, s.dbName, s.collectionName, s.collectionID, true)
|
||||
s.Error(err)
|
||||
s.ErrorIs(err, merr.ErrCollectionNotLoaded)
|
||||
s.Equal(int64(1), counter.Load())
|
||||
}
|
||||
|
||||
func (s *LBPolicySuite) TestSelectNodeWithExcludeClearing() {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test exclude nodes clearing when all replicas are excluded after cache refresh
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
twoNodes := []NodeInfo{
|
||||
{NodeID: 1, Address: "localhost:9000", Serviceable: true},
|
||||
{NodeID: 2, Address: "localhost:9001", Serviceable: true},
|
||||
}
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(twoNodes, nil).Once()
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(twoNodes, nil).Once()
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil).Once()
|
||||
|
||||
excludeNodes := typeutil.NewUniqueSet(int64(1), int64(2)) // Exclude all available nodes
|
||||
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Channel: s.channels[0],
|
||||
Nq: 1,
|
||||
}, &excludeNodes)
|
||||
|
||||
s.NoError(err)
|
||||
s.Equal(int64(1), targetNode.NodeID)
|
||||
s.Equal(0, excludeNodes.Len()) // Should be cleared when all replicas were excluded
|
||||
|
||||
// Test exclude nodes NOT cleared when only partial replicas are excluded
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
threeNodes := []NodeInfo{
|
||||
{NodeID: 1, Address: "localhost:9000", Serviceable: true},
|
||||
{NodeID: 2, Address: "localhost:9001", Serviceable: true},
|
||||
{NodeID: 3, Address: "localhost:9002", Serviceable: true},
|
||||
}
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(threeNodes, nil).Once()
|
||||
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
|
||||
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(2), nil).Once()
|
||||
|
||||
excludeNodes = typeutil.NewUniqueSet(int64(1)) // Only exclude node 1
|
||||
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Channel: s.channels[0],
|
||||
Nq: 1,
|
||||
}, &excludeNodes)
|
||||
|
||||
s.NoError(err)
|
||||
s.Equal(int64(2), targetNode.NodeID)
|
||||
s.Equal(1, excludeNodes.Len()) // Should NOT be cleared as not all replicas were excluded
|
||||
|
||||
// Test empty shard leaders scenario
|
||||
s.mgr.ExpectedCalls = nil
|
||||
s.lbBalancer.ExpectedCalls = nil
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return([]NodeInfo{}, nil).Once()
|
||||
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return([]NodeInfo{}, nil).Once()
|
||||
|
||||
excludeNodes = typeutil.NewUniqueSet(int64(1))
|
||||
_, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
|
||||
Db: s.dbName,
|
||||
CollectionName: s.collectionName,
|
||||
CollectionID: s.collectionID,
|
||||
Channel: s.channels[0],
|
||||
Nq: 1,
|
||||
}, &excludeNodes)
|
||||
|
||||
s.Error(err)
|
||||
s.Equal(1, excludeNodes.Len()) // Should NOT be cleared for empty shard leaders
|
||||
}
|
||||
|
||||
func TestLBPolicySuite(t *testing.T) {
|
||||
suite.Run(t, new(LBPolicySuite))
|
||||
}
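
The tests above pin down the retry contract of selectNode: try the cached shard leaders first, refresh the cache on failure, and clear the exclude set only when every replica has already been excluded. The Go sketch below is a minimal, self-contained illustration of that pattern, not the Milvus implementation; getLeaders and pickNode are hypothetical stand-ins for the shard cache lookup and the balancer's SelectNode.

package main

import (
	"errors"
	"fmt"
)

// selectWithRetry sketches the cached-then-refreshed selection loop.
// getLeaders and pickNode are hypothetical stand-ins for the shard cache
// and the balancer used by the real LB policy.
func selectWithRetry(
	getLeaders func(withCache bool) []int64,
	pickNode func(candidates []int64) (int64, error),
	exclude map[int64]struct{},
) (int64, error) {
	try := func(withCache bool) (int64, error) {
		leaders := getLeaders(withCache)
		candidates := make([]int64, 0, len(leaders))
		for _, id := range leaders {
			if _, skip := exclude[id]; !skip {
				candidates = append(candidates, id)
			}
		}
		// When every replica is excluded, clear the set and retry with all of them.
		if len(candidates) == 0 && len(leaders) > 0 {
			for id := range exclude {
				delete(exclude, id)
			}
			candidates = leaders
		}
		if len(candidates) == 0 {
			return -1, errors.New("no shard leader available")
		}
		return pickNode(candidates)
	}

	// First attempt uses the cached view; on failure refresh and try once more.
	if node, err := try(true); err == nil {
		return node, nil
	}
	return try(false)
}

func main() {
	exclude := map[int64]struct{}{1: {}}
	node, err := selectWithRetry(
		func(bool) []int64 { return []int64{1, 2, 3} },
		func(c []int64) (int64, error) { return c[0], nil },
		exclude,
	)
	fmt.Println(node, err) // prints: 2 <nil>
}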
|
||||
@ -14,7 +14,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package proxy
|
||||
package shardclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -31,6 +31,7 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/conc"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
@ -42,9 +43,9 @@ type CostMetrics struct {
|
||||
}
|
||||
|
||||
type LookAsideBalancer struct {
|
||||
clientMgr shardClientMgr
|
||||
clientMgr ShardClientMgr
|
||||
|
||||
knownNodeInfos *typeutil.ConcurrentMap[int64, nodeInfo]
|
||||
knownNodeInfos *typeutil.ConcurrentMap[int64, NodeInfo]
|
||||
metricsMap *typeutil.ConcurrentMap[int64, *CostMetrics]
|
||||
// query node id -> number of consecutive heartbeat failures
|
||||
failedHeartBeatCounter *typeutil.ConcurrentMap[int64, *atomic.Int64]
|
||||
@ -62,18 +63,18 @@ type LookAsideBalancer struct {
|
||||
workloadToleranceFactor float64
|
||||
}
|
||||
|
||||
func NewLookAsideBalancer(clientMgr shardClientMgr) *LookAsideBalancer {
|
||||
func NewLookAsideBalancer(clientMgr ShardClientMgr) *LookAsideBalancer {
|
||||
balancer := &LookAsideBalancer{
|
||||
clientMgr: clientMgr,
|
||||
knownNodeInfos: typeutil.NewConcurrentMap[int64, nodeInfo](),
|
||||
knownNodeInfos: typeutil.NewConcurrentMap[int64, NodeInfo](),
|
||||
metricsMap: typeutil.NewConcurrentMap[int64, *CostMetrics](),
|
||||
failedHeartBeatCounter: typeutil.NewConcurrentMap[int64, *atomic.Int64](),
|
||||
closeCh: make(chan struct{}),
|
||||
}
|
||||
|
||||
balancer.metricExpireInterval = Params.ProxyCfg.CostMetricsExpireTime.GetAsInt64()
|
||||
balancer.checkWorkloadRequestNum = Params.ProxyCfg.CheckWorkloadRequestNum.GetAsInt64()
|
||||
balancer.workloadToleranceFactor = Params.ProxyCfg.WorkloadToleranceFactor.GetAsFloat()
|
||||
balancer.metricExpireInterval = paramtable.Get().ProxyCfg.CostMetricsExpireTime.GetAsInt64()
|
||||
balancer.checkWorkloadRequestNum = paramtable.Get().ProxyCfg.CheckWorkloadRequestNum.GetAsInt64()
|
||||
balancer.workloadToleranceFactor = paramtable.Get().ProxyCfg.WorkloadToleranceFactor.GetAsFloat()
|
||||
|
||||
return balancer
|
||||
}
|
||||
@ -90,9 +91,9 @@ func (b *LookAsideBalancer) Close() {
|
||||
})
|
||||
}
|
||||
|
||||
func (b *LookAsideBalancer) RegisterNodeInfo(nodeInfos []nodeInfo) {
|
||||
func (b *LookAsideBalancer) RegisterNodeInfo(nodeInfos []NodeInfo) {
|
||||
for _, node := range nodeInfos {
|
||||
b.knownNodeInfos.Insert(node.nodeID, node)
|
||||
b.knownNodeInfos.Insert(node.NodeID, node)
|
||||
}
|
||||
}
|
||||
|
||||
@ -227,7 +228,7 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) {
|
||||
log := log.Ctx(ctx).WithRateGroup("proxy.LookAsideBalancer", 1, 60)
|
||||
defer b.wg.Done()
|
||||
|
||||
checkHealthInterval := Params.ProxyCfg.CheckQueryNodeHealthInterval.GetAsDuration(time.Millisecond)
|
||||
checkHealthInterval := paramtable.Get().ProxyCfg.CheckQueryNodeHealthInterval.GetAsDuration(time.Millisecond)
|
||||
ticker := time.NewTicker(checkHealthInterval)
|
||||
defer ticker.Stop()
|
||||
log.Info("Start check query node health loop")
|
||||
@ -241,11 +242,11 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) {
|
||||
case <-ticker.C:
|
||||
var futures []*conc.Future[any]
|
||||
now := time.Now()
|
||||
b.knownNodeInfos.Range(func(node int64, info nodeInfo) bool {
|
||||
b.knownNodeInfos.Range(func(node int64, info NodeInfo) bool {
|
||||
futures = append(futures, pool.Submit(func() (any, error) {
|
||||
metrics, ok := b.metricsMap.Get(node)
|
||||
if !ok || now.UnixMilli()-metrics.ts.Load() > checkHealthInterval.Milliseconds() {
|
||||
checkTimeout := Params.ProxyCfg.HealthCheckTimeout.GetAsDuration(time.Millisecond)
|
||||
checkTimeout := paramtable.Get().ProxyCfg.HealthCheckTimeout.GetAsDuration(time.Millisecond)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), checkTimeout)
|
||||
defer cancel()
|
||||
|
||||
@ -301,14 +302,14 @@ func (b *LookAsideBalancer) trySetQueryNodeUnReachable(node int64, err error) {
|
||||
zap.Int64("times", failures.Load()),
|
||||
zap.Error(err))
|
||||
|
||||
if failures.Load() < Params.ProxyCfg.RetryTimesOnHealthCheck.GetAsInt64() {
|
||||
if failures.Load() < paramtable.Get().ProxyCfg.RetryTimesOnHealthCheck.GetAsInt64() {
|
||||
return
|
||||
}
|
||||
|
||||
// if the total time of consecutive heartbeat failures reaches the session.ttl, remove the offline query node
|
||||
limit := Params.CommonCfg.SessionTTL.GetAsDuration(time.Second).Seconds() /
|
||||
Params.ProxyCfg.HealthCheckTimeout.GetAsDuration(time.Millisecond).Seconds()
|
||||
if failures.Load() > Params.ProxyCfg.RetryTimesOnHealthCheck.GetAsInt64() && float64(failures.Load()) >= limit {
|
||||
limit := paramtable.Get().CommonCfg.SessionTTL.GetAsDuration(time.Second).Seconds() /
|
||||
paramtable.Get().ProxyCfg.HealthCheckTimeout.GetAsDuration(time.Millisecond).Seconds()
|
||||
if failures.Load() > paramtable.Get().ProxyCfg.RetryTimesOnHealthCheck.GetAsInt64() && float64(failures.Load()) >= limit {
|
||||
log.Info("the heartbeat failures has reach it's upper limit, remove the query node",
|
||||
zap.Int64("nodeID", node))
|
||||
// stop the heartbeat
|
||||
@ -14,7 +14,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package proxy
|
||||
package shardclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -32,6 +32,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
)
|
||||
|
||||
type LookAsideBalancerSuite struct {
|
||||
@ -308,12 +309,12 @@ func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() {
|
||||
},
|
||||
}, nil).Maybe()
|
||||
suite.clientMgr.ExpectedCalls = nil
|
||||
suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ni nodeInfo) (types.QueryNodeClient, error) {
|
||||
if ni.nodeID == 1 {
|
||||
suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ni NodeInfo) (types.QueryNodeClient, error) {
|
||||
if ni.NodeID == 1 {
|
||||
return qn, nil
|
||||
}
|
||||
|
||||
if ni.nodeID == 2 {
|
||||
if ni.NodeID == 2 {
|
||||
return qn2, nil
|
||||
}
|
||||
return nil, errors.New("unexpected node")
|
||||
@ -323,19 +324,19 @@ func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() {
|
||||
metrics1.ts.Store(time.Now().UnixMilli())
|
||||
metrics1.unavailable.Store(true)
|
||||
suite.balancer.metricsMap.Insert(1, metrics1)
|
||||
suite.balancer.RegisterNodeInfo([]nodeInfo{
|
||||
suite.balancer.RegisterNodeInfo([]NodeInfo{
|
||||
{
|
||||
nodeID: 1,
|
||||
NodeID: 1,
|
||||
},
|
||||
})
|
||||
metrics2 := &CostMetrics{}
|
||||
metrics2.ts.Store(time.Now().UnixMilli())
|
||||
metrics2.unavailable.Store(true)
|
||||
suite.balancer.metricsMap.Insert(2, metrics2)
|
||||
suite.balancer.knownNodeInfos.Insert(2, nodeInfo{})
|
||||
suite.balancer.RegisterNodeInfo([]nodeInfo{
|
||||
suite.balancer.knownNodeInfos.Insert(2, NodeInfo{})
|
||||
suite.balancer.RegisterNodeInfo([]NodeInfo{
|
||||
{
|
||||
nodeID: 2,
|
||||
NodeID: 2,
|
||||
},
|
||||
})
|
||||
suite.Eventually(func() bool {
|
||||
@ -363,9 +364,9 @@ func (suite *LookAsideBalancerSuite) TestGetClientFailed() {
|
||||
metrics1.ts.Store(time.Now().UnixMilli())
|
||||
metrics1.unavailable.Store(true)
|
||||
suite.balancer.metricsMap.Insert(2, metrics1)
|
||||
suite.balancer.RegisterNodeInfo([]nodeInfo{
|
||||
suite.balancer.RegisterNodeInfo([]NodeInfo{
|
||||
{
|
||||
nodeID: 2,
|
||||
NodeID: 2,
|
||||
},
|
||||
})
|
||||
|
||||
@ -398,9 +399,9 @@ func (suite *LookAsideBalancerSuite) TestNodeRecover() {
|
||||
metrics1 := &CostMetrics{}
|
||||
metrics1.ts.Store(time.Now().UnixMilli())
|
||||
suite.balancer.metricsMap.Insert(3, metrics1)
|
||||
suite.balancer.RegisterNodeInfo([]nodeInfo{
|
||||
suite.balancer.RegisterNodeInfo([]NodeInfo{
|
||||
{
|
||||
nodeID: 3,
|
||||
NodeID: 3,
|
||||
},
|
||||
})
|
||||
suite.Eventually(func() bool {
|
||||
@ -415,8 +416,9 @@ func (suite *LookAsideBalancerSuite) TestNodeRecover() {
|
||||
}
|
||||
|
||||
func (suite *LookAsideBalancerSuite) TestNodeOffline() {
|
||||
Params.Save(Params.CommonCfg.SessionTTL.Key, "10")
|
||||
Params.Save(Params.ProxyCfg.HealthCheckTimeout.Key, "1000")
|
||||
params := paramtable.Get()
|
||||
params.Save(params.CommonCfg.SessionTTL.Key, "10")
|
||||
params.Save(params.ProxyCfg.HealthCheckTimeout.Key, "1000")
|
||||
// mock qn down for a while and then recover
|
||||
qn3 := mocks.NewMockQueryNodeClient(suite.T())
|
||||
suite.clientMgr.ExpectedCalls = nil
|
||||
@ -430,9 +432,9 @@ func (suite *LookAsideBalancerSuite) TestNodeOffline() {
|
||||
metrics1 := &CostMetrics{}
|
||||
metrics1.ts.Store(time.Now().UnixMilli())
|
||||
suite.balancer.metricsMap.Insert(3, metrics1)
|
||||
suite.balancer.RegisterNodeInfo([]nodeInfo{
|
||||
suite.balancer.RegisterNodeInfo([]NodeInfo{
|
||||
{
|
||||
nodeID: 3,
|
||||
NodeID: 3,
|
||||
},
|
||||
})
|
||||
suite.Eventually(func() bool {
|
||||
337
internal/proxy/shardclient/manager.go
Normal file
@ -0,0 +1,337 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package shardclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/registry"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/metrics"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
type ShardClientMgr interface {
|
||||
GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]NodeInfo, error)
|
||||
GetShardLeaderList(ctx context.Context, database, collectionName string, collectionID int64, withCache bool) ([]string, error)
|
||||
DeprecateShardCache(database, collectionName string)
|
||||
InvalidateShardLeaderCache(collections []int64)
|
||||
ListShardLocation() map[int64]NodeInfo
|
||||
RemoveDatabase(database string)
|
||||
|
||||
GetClient(ctx context.Context, nodeInfo NodeInfo) (types.QueryNodeClient, error)
|
||||
SetClientCreatorFunc(creator queryNodeCreatorFunc)
|
||||
|
||||
Start()
|
||||
Close()
|
||||
}
|
||||
|
||||
type shardClientMgrImpl struct {
|
||||
clients *typeutil.ConcurrentMap[UniqueID, *shardClient]
|
||||
clientCreator queryNodeCreatorFunc
|
||||
closeCh chan struct{}
|
||||
|
||||
purgeInterval time.Duration
|
||||
expiredDuration time.Duration
|
||||
|
||||
mixCoord types.MixCoordClient
|
||||
|
||||
leaderMut sync.RWMutex
|
||||
collLeader map[string]map[string]*shardLeaders // database -> collectionName -> collection_leaders
|
||||
}
|
||||
|
||||
const (
|
||||
defaultPurgeInterval = 600 * time.Second
|
||||
defaultExpiredDuration = 60 * time.Minute
|
||||
)
|
||||
|
||||
// shardClientMgrOpt provides a way to set params in ShardClientMgr
|
||||
type shardClientMgrOpt func(s ShardClientMgr)
|
||||
|
||||
func withShardClientCreator(creator queryNodeCreatorFunc) shardClientMgrOpt {
|
||||
return func(s ShardClientMgr) { s.SetClientCreatorFunc(creator) }
|
||||
}
|
||||
|
||||
func DefaultQueryNodeClientCreator(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
|
||||
return registry.GetInMemoryResolver().ResolveQueryNode(ctx, addr, nodeID)
|
||||
}
|
||||
|
||||
// NewShardClientMgr creates a new shardClientMgr
|
||||
func NewShardClientMgr(mixCoord types.MixCoordClient, options ...shardClientMgrOpt) *shardClientMgrImpl {
|
||||
s := &shardClientMgrImpl{
|
||||
clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](),
|
||||
clientCreator: DefaultQueryNodeClientCreator,
|
||||
closeCh: make(chan struct{}),
|
||||
purgeInterval: defaultPurgeInterval,
|
||||
expiredDuration: defaultExpiredDuration,
|
||||
|
||||
collLeader: make(map[string]map[string]*shardLeaders),
|
||||
mixCoord: mixCoord,
|
||||
}
|
||||
for _, opt := range options {
|
||||
opt(s)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *shardClientMgrImpl) SetClientCreatorFunc(creator queryNodeCreatorFunc) {
|
||||
c.clientCreator = creator
|
||||
}
|
||||
|
||||
func (m *shardClientMgrImpl) GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]NodeInfo, error) {
|
||||
method := "GetShard"
|
||||
// check cache first
|
||||
cacheShardLeaders := m.getCachedShardLeaders(database, collectionName, method)
|
||||
if cacheShardLeaders == nil || !withCache {
|
||||
// refresh shard leader cache
|
||||
newShardLeaders, err := m.updateShardLocationCache(ctx, database, collectionName, collectionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cacheShardLeaders = newShardLeaders
|
||||
}
|
||||
|
||||
return cacheShardLeaders.Get(channel), nil
|
||||
}
|
||||
|
||||
func (m *shardClientMgrImpl) GetShardLeaderList(ctx context.Context, database, collectionName string, collectionID int64, withCache bool) ([]string, error) {
|
||||
method := "GetShardLeaderList"
|
||||
// check cache first
|
||||
cacheShardLeaders := m.getCachedShardLeaders(database, collectionName, method)
|
||||
if cacheShardLeaders == nil || !withCache {
|
||||
// refresh shard leader cache
|
||||
newShardLeaders, err := m.updateShardLocationCache(ctx, database, collectionName, collectionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cacheShardLeaders = newShardLeaders
|
||||
}
|
||||
|
||||
return cacheShardLeaders.GetShardLeaderList(), nil
|
||||
}
|
||||
|
||||
func (m *shardClientMgrImpl) getCachedShardLeaders(database, collectionName, caller string) *shardLeaders {
|
||||
m.leaderMut.RLock()
|
||||
var cacheShardLeaders *shardLeaders
|
||||
db, ok := m.collLeader[database]
|
||||
if !ok {
|
||||
cacheShardLeaders = nil
|
||||
} else {
|
||||
cacheShardLeaders = db[collectionName]
|
||||
}
|
||||
m.leaderMut.RUnlock()
|
||||
|
||||
if cacheShardLeaders != nil {
|
||||
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), caller, metrics.CacheHitLabel).Inc()
|
||||
} else {
|
||||
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), caller, metrics.CacheMissLabel).Inc()
|
||||
}
|
||||
|
||||
return cacheShardLeaders
|
||||
}
|
||||
|
||||
func (m *shardClientMgrImpl) updateShardLocationCache(ctx context.Context, database, collectionName string, collectionID int64) (*shardLeaders, error) {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.String("db", database),
|
||||
zap.String("collectionName", collectionName),
|
||||
zap.Int64("collectionID", collectionID))
|
||||
|
||||
method := "updateShardLocationCache"
|
||||
tr := timerecord.NewTimeRecorder(method)
|
||||
defer metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).
|
||||
Observe(float64(tr.ElapseSpan().Milliseconds()))
|
||||
|
||||
req := &querypb.GetShardLeadersRequest{
|
||||
Base: commonpbutil.NewMsgBase(
|
||||
commonpbutil.WithMsgType(commonpb.MsgType_GetShardLeaders),
|
||||
commonpbutil.WithSourceID(paramtable.GetNodeID()),
|
||||
),
|
||||
CollectionID: collectionID,
|
||||
WithUnserviceableShards: true,
|
||||
}
|
||||
resp, err := m.mixCoord.GetShardLeaders(ctx, req)
|
||||
if err := merr.CheckRPCCall(resp.GetStatus(), err); err != nil {
|
||||
log.Error("failed to get shard locations",
|
||||
zap.Int64("collectionID", collectionID),
|
||||
zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
shards := parseShardLeaderList2QueryNode(resp.GetShards())
|
||||
|
||||
// convert shards map to string for logging
|
||||
if log.Logger.Level() == zap.DebugLevel {
|
||||
shardStr := make([]string, 0, len(shards))
|
||||
for channel, nodes := range shards {
|
||||
nodeStrs := make([]string, 0, len(nodes))
|
||||
for _, node := range nodes {
|
||||
nodeStrs = append(nodeStrs, node.String())
|
||||
}
|
||||
shardStr = append(shardStr, fmt.Sprintf("%s:[%s]", channel, strings.Join(nodeStrs, ", ")))
|
||||
}
|
||||
log.Debug("update shard leader cache", zap.String("newShardLeaders", strings.Join(shardStr, ", ")))
|
||||
}
|
||||
|
||||
newShardLeaders := &shardLeaders{
|
||||
collectionID: collectionID,
|
||||
shardLeaders: shards,
|
||||
idx: atomic.NewInt64(0),
|
||||
}
|
||||
|
||||
m.leaderMut.Lock()
|
||||
if _, ok := m.collLeader[database]; !ok {
|
||||
m.collLeader[database] = make(map[string]*shardLeaders)
|
||||
}
|
||||
m.collLeader[database][collectionName] = newShardLeaders
|
||||
m.leaderMut.Unlock()
|
||||
|
||||
return newShardLeaders, nil
|
||||
}
|
||||
|
||||
func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) map[string][]NodeInfo {
|
||||
shard2QueryNodes := make(map[string][]NodeInfo)
|
||||
|
||||
for _, leaders := range shardsLeaders {
|
||||
qns := make([]NodeInfo, len(leaders.GetNodeIds()))
|
||||
|
||||
for j := range qns {
|
||||
qns[j] = NodeInfo{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j], leaders.GetServiceable()[j]}
|
||||
}
|
||||
|
||||
shard2QueryNodes[leaders.GetChannelName()] = qns
|
||||
}
|
||||
|
||||
return shard2QueryNodes
|
||||
}
|
||||
|
||||
// ListShardLocation is used by shard client garbage collection
|
||||
func (m *shardClientMgrImpl) ListShardLocation() map[int64]NodeInfo {
|
||||
m.leaderMut.RLock()
|
||||
defer m.leaderMut.RUnlock()
|
||||
shardLeaderInfo := make(map[int64]NodeInfo)
|
||||
|
||||
for _, dbInfo := range m.collLeader {
|
||||
for _, shardLeaders := range dbInfo {
|
||||
for _, nodeInfos := range shardLeaders.shardLeaders {
|
||||
for _, node := range nodeInfos {
|
||||
shardLeaderInfo[node.NodeID] = node
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return shardLeaderInfo
|
||||
}
|
||||
|
||||
func (m *shardClientMgrImpl) RemoveDatabase(database string) {
|
||||
m.leaderMut.Lock()
|
||||
defer m.leaderMut.Unlock()
|
||||
delete(m.collLeader, database)
|
||||
}
|
||||
|
||||
// DeprecateShardCache clears the shard leader cache of a collection
|
||||
func (m *shardClientMgrImpl) DeprecateShardCache(database, collectionName string) {
|
||||
log.Info("deprecate shard cache for collection", zap.String("collectionName", collectionName))
|
||||
m.leaderMut.Lock()
|
||||
defer m.leaderMut.Unlock()
|
||||
dbInfo, ok := m.collLeader[database]
|
||||
if ok {
|
||||
delete(dbInfo, collectionName)
|
||||
if len(dbInfo) == 0 {
|
||||
delete(m.collLeader, database)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateShardLeaderCache is called when a shard leader balance has happened
|
||||
func (m *shardClientMgrImpl) InvalidateShardLeaderCache(collections []int64) {
|
||||
log.Info("Invalidate shard cache for collections", zap.Int64s("collectionIDs", collections))
|
||||
m.leaderMut.Lock()
|
||||
defer m.leaderMut.Unlock()
|
||||
collectionSet := typeutil.NewUniqueSet(collections...)
|
||||
for dbName, dbInfo := range m.collLeader {
|
||||
for collectionName, shardLeaders := range dbInfo {
|
||||
if collectionSet.Contain(shardLeaders.collectionID) {
|
||||
delete(dbInfo, collectionName)
|
||||
}
|
||||
}
|
||||
if len(dbInfo) == 0 {
|
||||
delete(m.collLeader, dbName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *shardClientMgrImpl) GetClient(ctx context.Context, info NodeInfo) (types.QueryNodeClient, error) {
|
||||
client, _ := c.clients.GetOrInsert(info.NodeID, newShardClient(info, c.clientCreator, c.expiredDuration))
|
||||
return client.getClient(ctx)
|
||||
}
|
||||
|
||||
// PurgeClient periodically purges clients that have not been used for a long time
|
||||
func (c *shardClientMgrImpl) PurgeClient() {
|
||||
ticker := time.NewTicker(c.purgeInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.closeCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
shardLocations := c.ListShardLocation()
|
||||
c.clients.Range(func(key UniqueID, value *shardClient) bool {
|
||||
if _, ok := shardLocations[key]; !ok {
|
||||
// if the client has not been used for more than 1 hour and is no longer a delegator, remove it
|
||||
if value.isExpired() {
|
||||
closed := value.Close(false)
|
||||
if closed {
|
||||
c.clients.Remove(key)
|
||||
log.Info("remove idle node client", zap.Int64("nodeID", key))
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *shardClientMgrImpl) Start() {
|
||||
go c.PurgeClient()
|
||||
}
|
||||
|
||||
// Close releases all clients
|
||||
func (c *shardClientMgrImpl) Close() {
|
||||
close(c.closeCh)
|
||||
c.clients.Range(func(key UniqueID, value *shardClient) bool {
|
||||
value.Close(true)
|
||||
c.clients.Remove(key)
|
||||
return true
|
||||
})
|
||||
}
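
The interface and implementation above make the shard client manager the single entry point for shard metadata and query node clients. The sketch below is a hedged usage example, assuming it lives in the shardclient package with an already-connected types.MixCoordClient; the database name, collection name, and collection ID are placeholders, and error handling is trimmed to the essentials.

// demoShardAccess is an illustrative sketch (not production code): list the
// channels of a collection, resolve each channel's shard leaders, then fetch
// a query node client for the first returned leader.
func demoShardAccess(ctx context.Context, mixCoord types.MixCoordClient) error {
	mgr := NewShardClientMgr(mixCoord)
	mgr.Start()
	defer mgr.Close()

	// Placeholder identifiers; a real caller would pass the resolved values.
	db, collection, collectionID := "default", "demo_collection", int64(100)

	channels, err := mgr.GetShardLeaderList(ctx, db, collection, collectionID, true)
	if err != nil {
		return err
	}
	for _, channel := range channels {
		leaders, err := mgr.GetShard(ctx, true, db, collection, collectionID, channel)
		if err != nil || len(leaders) == 0 {
			continue
		}
		qn, err := mgr.GetClient(ctx, leaders[0]) // leaders[0] is a NodeInfo
		if err != nil {
			continue
		}
		_ = qn // issue Search/Query RPCs against this shard leader here
	}
	return nil
}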
|
||||
@ -1,6 +1,6 @@
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
|
||||
package proxy
|
||||
package shardclient
|
||||
|
||||
import (
|
||||
context "context"
|
||||
@ -89,7 +89,7 @@ func (_c *MockLBBalancer_Close_Call) RunAndReturn(run func()) *MockLBBalancer_Cl
|
||||
}
|
||||
|
||||
// RegisterNodeInfo provides a mock function with given fields: nodeInfos
|
||||
func (_m *MockLBBalancer) RegisterNodeInfo(nodeInfos []nodeInfo) {
|
||||
func (_m *MockLBBalancer) RegisterNodeInfo(nodeInfos []NodeInfo) {
|
||||
_m.Called(nodeInfos)
|
||||
}
|
||||
|
||||
@ -99,14 +99,14 @@ type MockLBBalancer_RegisterNodeInfo_Call struct {
|
||||
}
|
||||
|
||||
// RegisterNodeInfo is a helper method to define mock.On call
|
||||
// - nodeInfos []nodeInfo
|
||||
// - nodeInfos []NodeInfo
|
||||
func (_e *MockLBBalancer_Expecter) RegisterNodeInfo(nodeInfos interface{}) *MockLBBalancer_RegisterNodeInfo_Call {
|
||||
return &MockLBBalancer_RegisterNodeInfo_Call{Call: _e.mock.On("RegisterNodeInfo", nodeInfos)}
|
||||
}
|
||||
|
||||
func (_c *MockLBBalancer_RegisterNodeInfo_Call) Run(run func(nodeInfos []nodeInfo)) *MockLBBalancer_RegisterNodeInfo_Call {
|
||||
func (_c *MockLBBalancer_RegisterNodeInfo_Call) Run(run func(nodeInfos []NodeInfo)) *MockLBBalancer_RegisterNodeInfo_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]nodeInfo))
|
||||
run(args[0].([]NodeInfo))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
@ -116,7 +116,7 @@ func (_c *MockLBBalancer_RegisterNodeInfo_Call) Return() *MockLBBalancer_Registe
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockLBBalancer_RegisterNodeInfo_Call) RunAndReturn(run func([]nodeInfo)) *MockLBBalancer_RegisterNodeInfo_Call {
|
||||
func (_c *MockLBBalancer_RegisterNodeInfo_Call) RunAndReturn(run func([]NodeInfo)) *MockLBBalancer_RegisterNodeInfo_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
@ -1,6 +1,6 @@
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
|
||||
package proxy
|
||||
package shardclient
|
||||
|
||||
import (
|
||||
context "context"
|
||||
465
internal/proxy/shardclient/mock_shardclient_manager.go
Normal file
@ -0,0 +1,465 @@
|
||||
// Code generated by mockery v2.53.3. DO NOT EDIT.
|
||||
|
||||
package shardclient
|
||||
|
||||
import (
|
||||
context "context"
|
||||
|
||||
types "github.com/milvus-io/milvus/internal/types"
|
||||
mock "github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockShardClientManager is an autogenerated mock type for the ShardClientMgr type
|
||||
type MockShardClientManager struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
type MockShardClientManager_Expecter struct {
|
||||
mock *mock.Mock
|
||||
}
|
||||
|
||||
func (_m *MockShardClientManager) EXPECT() *MockShardClientManager_Expecter {
|
||||
return &MockShardClientManager_Expecter{mock: &_m.Mock}
|
||||
}
|
||||
|
||||
// Close provides a mock function with no fields
|
||||
func (_m *MockShardClientManager) Close() {
|
||||
_m.Called()
|
||||
}
|
||||
|
||||
// MockShardClientManager_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
|
||||
type MockShardClientManager_Close_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Close is a helper method to define mock.On call
|
||||
func (_e *MockShardClientManager_Expecter) Close() *MockShardClientManager_Close_Call {
|
||||
return &MockShardClientManager_Close_Call{Call: _e.mock.On("Close")}
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_Close_Call) Run(run func()) *MockShardClientManager_Close_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_Close_Call) Return() *MockShardClientManager_Close_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_Close_Call) RunAndReturn(run func()) *MockShardClientManager_Close_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// DeprecateShardCache provides a mock function with given fields: database, collectionName
|
||||
func (_m *MockShardClientManager) DeprecateShardCache(database string, collectionName string) {
|
||||
_m.Called(database, collectionName)
|
||||
}
|
||||
|
||||
// MockShardClientManager_DeprecateShardCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeprecateShardCache'
|
||||
type MockShardClientManager_DeprecateShardCache_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// DeprecateShardCache is a helper method to define mock.On call
|
||||
// - database string
|
||||
// - collectionName string
|
||||
func (_e *MockShardClientManager_Expecter) DeprecateShardCache(database interface{}, collectionName interface{}) *MockShardClientManager_DeprecateShardCache_Call {
|
||||
return &MockShardClientManager_DeprecateShardCache_Call{Call: _e.mock.On("DeprecateShardCache", database, collectionName)}
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_DeprecateShardCache_Call) Run(run func(database string, collectionName string)) *MockShardClientManager_DeprecateShardCache_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string), args[1].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_DeprecateShardCache_Call) Return() *MockShardClientManager_DeprecateShardCache_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_DeprecateShardCache_Call) RunAndReturn(run func(string, string)) *MockShardClientManager_DeprecateShardCache_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetClient provides a mock function with given fields: ctx, nodeInfo
|
||||
func (_m *MockShardClientManager) GetClient(ctx context.Context, nodeInfo NodeInfo) (types.QueryNodeClient, error) {
|
||||
ret := _m.Called(ctx, nodeInfo)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetClient")
|
||||
}
|
||||
|
||||
var r0 types.QueryNodeClient
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, NodeInfo) (types.QueryNodeClient, error)); ok {
|
||||
return rf(ctx, nodeInfo)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, NodeInfo) types.QueryNodeClient); ok {
|
||||
r0 = rf(ctx, nodeInfo)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(types.QueryNodeClient)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, NodeInfo) error); ok {
|
||||
r1 = rf(ctx, nodeInfo)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockShardClientManager_GetClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetClient'
|
||||
type MockShardClientManager_GetClient_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetClient is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - nodeInfo NodeInfo
|
||||
func (_e *MockShardClientManager_Expecter) GetClient(ctx interface{}, nodeInfo interface{}) *MockShardClientManager_GetClient_Call {
|
||||
return &MockShardClientManager_GetClient_Call{Call: _e.mock.On("GetClient", ctx, nodeInfo)}
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_GetClient_Call) Run(run func(ctx context.Context, nodeInfo NodeInfo)) *MockShardClientManager_GetClient_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(NodeInfo))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_GetClient_Call) Return(_a0 types.QueryNodeClient, _a1 error) *MockShardClientManager_GetClient_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_GetClient_Call) RunAndReturn(run func(context.Context, NodeInfo) (types.QueryNodeClient, error)) *MockShardClientManager_GetClient_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetShard provides a mock function with given fields: ctx, withCache, database, collectionName, collectionID, channel
|
||||
func (_m *MockShardClientManager) GetShard(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64, channel string) ([]NodeInfo, error) {
|
||||
ret := _m.Called(ctx, withCache, database, collectionName, collectionID, channel)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetShard")
|
||||
}
|
||||
|
||||
var r0 []NodeInfo
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64, string) ([]NodeInfo, error)); ok {
|
||||
return rf(ctx, withCache, database, collectionName, collectionID, channel)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64, string) []NodeInfo); ok {
|
||||
r0 = rf(ctx, withCache, database, collectionName, collectionID, channel)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]NodeInfo)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, bool, string, string, int64, string) error); ok {
|
||||
r1 = rf(ctx, withCache, database, collectionName, collectionID, channel)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockShardClientManager_GetShard_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetShard'
|
||||
type MockShardClientManager_GetShard_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetShard is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - withCache bool
|
||||
// - database string
|
||||
// - collectionName string
|
||||
// - collectionID int64
|
||||
// - channel string
|
||||
func (_e *MockShardClientManager_Expecter) GetShard(ctx interface{}, withCache interface{}, database interface{}, collectionName interface{}, collectionID interface{}, channel interface{}) *MockShardClientManager_GetShard_Call {
|
||||
return &MockShardClientManager_GetShard_Call{Call: _e.mock.On("GetShard", ctx, withCache, database, collectionName, collectionID, channel)}
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_GetShard_Call) Run(run func(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64, channel string)) *MockShardClientManager_GetShard_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(bool), args[2].(string), args[3].(string), args[4].(int64), args[5].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_GetShard_Call) Return(_a0 []NodeInfo, _a1 error) *MockShardClientManager_GetShard_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_GetShard_Call) RunAndReturn(run func(context.Context, bool, string, string, int64, string) ([]NodeInfo, error)) *MockShardClientManager_GetShard_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// GetShardLeaderList provides a mock function with given fields: ctx, database, collectionName, collectionID, withCache
|
||||
func (_m *MockShardClientManager) GetShardLeaderList(ctx context.Context, database string, collectionName string, collectionID int64, withCache bool) ([]string, error) {
|
||||
ret := _m.Called(ctx, database, collectionName, collectionID, withCache)
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for GetShardLeaderList")
|
||||
}
|
||||
|
||||
var r0 []string
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, bool) ([]string, error)); ok {
|
||||
return rf(ctx, database, collectionName, collectionID, withCache)
|
||||
}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, bool) []string); ok {
|
||||
r0 = rf(ctx, database, collectionName, collectionID, withCache)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]string)
|
||||
}
|
||||
}
|
||||
|
||||
if rf, ok := ret.Get(1).(func(context.Context, string, string, int64, bool) error); ok {
|
||||
r1 = rf(ctx, database, collectionName, collectionID, withCache)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// MockShardClientManager_GetShardLeaderList_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetShardLeaderList'
|
||||
type MockShardClientManager_GetShardLeaderList_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// GetShardLeaderList is a helper method to define mock.On call
|
||||
// - ctx context.Context
|
||||
// - database string
|
||||
// - collectionName string
|
||||
// - collectionID int64
|
||||
// - withCache bool
|
||||
func (_e *MockShardClientManager_Expecter) GetShardLeaderList(ctx interface{}, database interface{}, collectionName interface{}, collectionID interface{}, withCache interface{}) *MockShardClientManager_GetShardLeaderList_Call {
|
||||
return &MockShardClientManager_GetShardLeaderList_Call{Call: _e.mock.On("GetShardLeaderList", ctx, database, collectionName, collectionID, withCache)}
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_GetShardLeaderList_Call) Run(run func(ctx context.Context, database string, collectionName string, collectionID int64, withCache bool)) *MockShardClientManager_GetShardLeaderList_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(int64), args[4].(bool))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_GetShardLeaderList_Call) Return(_a0 []string, _a1 error) *MockShardClientManager_GetShardLeaderList_Call {
|
||||
_c.Call.Return(_a0, _a1)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_GetShardLeaderList_Call) RunAndReturn(run func(context.Context, string, string, int64, bool) ([]string, error)) *MockShardClientManager_GetShardLeaderList_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// InvalidateShardLeaderCache provides a mock function with given fields: collections
|
||||
func (_m *MockShardClientManager) InvalidateShardLeaderCache(collections []int64) {
|
||||
_m.Called(collections)
|
||||
}
|
||||
|
||||
// MockShardClientManager_InvalidateShardLeaderCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateShardLeaderCache'
|
||||
type MockShardClientManager_InvalidateShardLeaderCache_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// InvalidateShardLeaderCache is a helper method to define mock.On call
|
||||
// - collections []int64
|
||||
func (_e *MockShardClientManager_Expecter) InvalidateShardLeaderCache(collections interface{}) *MockShardClientManager_InvalidateShardLeaderCache_Call {
|
||||
return &MockShardClientManager_InvalidateShardLeaderCache_Call{Call: _e.mock.On("InvalidateShardLeaderCache", collections)}
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_InvalidateShardLeaderCache_Call) Run(run func(collections []int64)) *MockShardClientManager_InvalidateShardLeaderCache_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].([]int64))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_InvalidateShardLeaderCache_Call) Return() *MockShardClientManager_InvalidateShardLeaderCache_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_InvalidateShardLeaderCache_Call) RunAndReturn(run func([]int64)) *MockShardClientManager_InvalidateShardLeaderCache_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// ListShardLocation provides a mock function with no fields
|
||||
func (_m *MockShardClientManager) ListShardLocation() map[int64]NodeInfo {
|
||||
ret := _m.Called()
|
||||
|
||||
if len(ret) == 0 {
|
||||
panic("no return value specified for ListShardLocation")
|
||||
}
|
||||
|
||||
var r0 map[int64]NodeInfo
|
||||
if rf, ok := ret.Get(0).(func() map[int64]NodeInfo); ok {
|
||||
r0 = rf()
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(map[int64]NodeInfo)
|
||||
}
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// MockShardClientManager_ListShardLocation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListShardLocation'
|
||||
type MockShardClientManager_ListShardLocation_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// ListShardLocation is a helper method to define mock.On call
|
||||
func (_e *MockShardClientManager_Expecter) ListShardLocation() *MockShardClientManager_ListShardLocation_Call {
|
||||
return &MockShardClientManager_ListShardLocation_Call{Call: _e.mock.On("ListShardLocation")}
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_ListShardLocation_Call) Run(run func()) *MockShardClientManager_ListShardLocation_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_ListShardLocation_Call) Return(_a0 map[int64]NodeInfo) *MockShardClientManager_ListShardLocation_Call {
|
||||
_c.Call.Return(_a0)
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_ListShardLocation_Call) RunAndReturn(run func() map[int64]NodeInfo) *MockShardClientManager_ListShardLocation_Call {
|
||||
_c.Call.Return(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// RemoveDatabase provides a mock function with given fields: database
|
||||
func (_m *MockShardClientManager) RemoveDatabase(database string) {
|
||||
_m.Called(database)
|
||||
}
|
||||
|
||||
// MockShardClientManager_RemoveDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveDatabase'
|
||||
type MockShardClientManager_RemoveDatabase_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// RemoveDatabase is a helper method to define mock.On call
|
||||
// - database string
|
||||
func (_e *MockShardClientManager_Expecter) RemoveDatabase(database interface{}) *MockShardClientManager_RemoveDatabase_Call {
|
||||
return &MockShardClientManager_RemoveDatabase_Call{Call: _e.mock.On("RemoveDatabase", database)}
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_RemoveDatabase_Call) Run(run func(database string)) *MockShardClientManager_RemoveDatabase_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(string))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_RemoveDatabase_Call) Return() *MockShardClientManager_RemoveDatabase_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_RemoveDatabase_Call) RunAndReturn(run func(string)) *MockShardClientManager_RemoveDatabase_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// SetClientCreatorFunc provides a mock function with given fields: creator
|
||||
func (_m *MockShardClientManager) SetClientCreatorFunc(creator queryNodeCreatorFunc) {
|
||||
_m.Called(creator)
|
||||
}
|
||||
|
||||
// MockShardClientManager_SetClientCreatorFunc_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetClientCreatorFunc'
|
||||
type MockShardClientManager_SetClientCreatorFunc_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// SetClientCreatorFunc is a helper method to define mock.On call
|
||||
// - creator queryNodeCreatorFunc
|
||||
func (_e *MockShardClientManager_Expecter) SetClientCreatorFunc(creator interface{}) *MockShardClientManager_SetClientCreatorFunc_Call {
|
||||
return &MockShardClientManager_SetClientCreatorFunc_Call{Call: _e.mock.On("SetClientCreatorFunc", creator)}
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_SetClientCreatorFunc_Call) Run(run func(creator queryNodeCreatorFunc)) *MockShardClientManager_SetClientCreatorFunc_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run(args[0].(queryNodeCreatorFunc))
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_SetClientCreatorFunc_Call) Return() *MockShardClientManager_SetClientCreatorFunc_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_SetClientCreatorFunc_Call) RunAndReturn(run func(queryNodeCreatorFunc)) *MockShardClientManager_SetClientCreatorFunc_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// Start provides a mock function with no fields
|
||||
func (_m *MockShardClientManager) Start() {
|
||||
_m.Called()
|
||||
}
|
||||
|
||||
// MockShardClientManager_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start'
|
||||
type MockShardClientManager_Start_Call struct {
|
||||
*mock.Call
|
||||
}
|
||||
|
||||
// Start is a helper method to define mock.On call
|
||||
func (_e *MockShardClientManager_Expecter) Start() *MockShardClientManager_Start_Call {
|
||||
return &MockShardClientManager_Start_Call{Call: _e.mock.On("Start")}
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_Start_Call) Run(run func()) *MockShardClientManager_Start_Call {
|
||||
_c.Call.Run(func(args mock.Arguments) {
|
||||
run()
|
||||
})
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_Start_Call) Return() *MockShardClientManager_Start_Call {
|
||||
_c.Call.Return()
|
||||
return _c
|
||||
}
|
||||
|
||||
func (_c *MockShardClientManager_Start_Call) RunAndReturn(run func()) *MockShardClientManager_Start_Call {
|
||||
_c.Run(run)
|
||||
return _c
|
||||
}
|
||||
|
||||
// NewMockShardClientManager creates a new instance of MockShardClientManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
|
||||
// The first argument is typically a *testing.T value.
|
||||
func NewMockShardClientManager(t interface {
|
||||
mock.TestingT
|
||||
Cleanup(func())
|
||||
}) *MockShardClientManager {
|
||||
mock := &MockShardClientManager{}
|
||||
mock.Mock.Test(t)
|
||||
|
||||
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||
|
||||
return mock
|
||||
}
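For reference, a minimal sketch of how the generated expecter API above is typically consumed in a test; the test name and concrete values are illustrative only, and it assumes the mock is generated in-package under shardclient like the other regenerated mocks:

package shardclient

import (
	"context"
	"testing"

	"github.com/stretchr/testify/mock"

	"github.com/milvus-io/milvus/internal/mocks"
)

// Illustrative test: expect exactly one GetClient call and hand back a mocked QueryNode client.
func TestGetClientExpectation(t *testing.T) {
	mgr := NewMockShardClientManager(t)
	qn := mocks.NewMockQueryNodeClient(t)

	mgr.EXPECT().
		GetClient(mock.Anything, mock.Anything).
		Return(qn, nil).
		Once()

	client, err := mgr.GetClient(context.Background(), NodeInfo{NodeID: 1, Address: "localhost:9000"})
	if err != nil || client == nil {
		t.Fatalf("unexpected GetClient result: %v", err)
	}
}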
|
||||
76
internal/proxy/shardclient/model.go
Normal file
@ -0,0 +1,76 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package shardclient
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/samber/lo"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
// shardLeaders wraps shard leader mapping for iteration.
|
||||
type shardLeaders struct {
|
||||
idx *atomic.Int64
|
||||
collectionID int64
|
||||
shardLeaders map[string][]NodeInfo
|
||||
}
|
||||
|
||||
func (sl *shardLeaders) Get(channel string) []NodeInfo {
|
||||
return sl.shardLeaders[channel]
|
||||
}
|
||||
|
||||
func (sl *shardLeaders) GetShardLeaderList() []string {
|
||||
return lo.Keys(sl.shardLeaders)
|
||||
}
|
||||
|
||||
type shardLeadersReader struct {
|
||||
leaders *shardLeaders
|
||||
idx int64
|
||||
}
|
||||
|
||||
// Shuffle returns the shard leader map with each channel's replica list shuffled.
|
||||
|
||||
result := make(map[string][]NodeInfo)
|
||||
for channel, leaders := range it.leaders.shardLeaders {
|
||||
l := len(leaders)
|
||||
// shuffle all replicas into a random order
|
||||
shuffled := make([]NodeInfo, l)
|
||||
for i, randIndex := range rand.Perm(l) {
|
||||
shuffled[i] = leaders[randIndex]
|
||||
}
|
||||
|
||||
// give each replica the same probability of being the first pick
|
||||
for index, leader := range shuffled {
|
||||
if leader == leaders[int(it.idx)%l] {
|
||||
shuffled[0], shuffled[index] = shuffled[index], shuffled[0]
|
||||
}
|
||||
}
|
||||
|
||||
result[channel] = shuffled
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetReader returns a shuffle reader for the shard leaders.
|
||||
func (sl *shardLeaders) GetReader() shardLeadersReader {
|
||||
idx := sl.idx.Inc()
|
||||
return shardLeadersReader{
|
||||
leaders: sl,
|
||||
idx: idx,
|
||||
}
|
||||
}
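A quick illustration of the rotation above, assuming the fragment lives inside the shardclient package (the helper name and values are illustrative only):

// exampleRotation shows how consecutive readers rotate the preferred (first) replica.
func exampleRotation() {
	sl := &shardLeaders{
		idx:          atomic.NewInt64(5),
		collectionID: 1,
		shardLeaders: map[string][]NodeInfo{
			"ch-1": {{NodeID: 1}, {NodeID: 2}, {NodeID: 3}},
		},
	}

	first := sl.GetReader().Shuffle()["ch-1"][0] // NodeID 1 (idx becomes 6, 6%3 == 0)
	first = sl.GetReader().Shuffle()["ch-1"][0]  // NodeID 2 (idx becomes 7, 7%3 == 1)
	first = sl.GetReader().Shuffle()["ch-1"][0]  // NodeID 3 (idx becomes 8, 8%3 == 2)
	_ = first
}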
|
||||
@ -13,7 +13,8 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package proxy
|
||||
|
||||
package shardclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -32,7 +33,7 @@ func NewRoundRobinBalancer() *RoundRobinBalancer {
|
||||
return &RoundRobinBalancer{}
|
||||
}
|
||||
|
||||
func (b *RoundRobinBalancer) RegisterNodeInfo(nodeInfos []nodeInfo) {}
|
||||
func (b *RoundRobinBalancer) RegisterNodeInfo(nodeInfos []NodeInfo) {}
|
||||
|
||||
func (b *RoundRobinBalancer) SelectNode(ctx context.Context, availableNodes []int64, cost int64) (int64, error) {
|
||||
if len(availableNodes) == 0 {
|
||||
@ -13,7 +13,8 @@
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
package proxy
|
||||
|
||||
package shardclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -1,4 +1,20 @@
|
||||
package proxy
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package shardclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
@ -10,30 +26,31 @@ import (
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/registry"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
type UniqueID = typeutil.UniqueID
|
||||
|
||||
type queryNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error)
|
||||
|
||||
type nodeInfo struct {
|
||||
nodeID UniqueID
|
||||
address string
|
||||
serviceable bool
|
||||
type NodeInfo struct {
|
||||
NodeID UniqueID
|
||||
Address string
|
||||
Serviceable bool
|
||||
}
|
||||
|
||||
func (n nodeInfo) String() string {
|
||||
return fmt.Sprintf("<NodeID: %d, serviceable: %v, address: %s>", n.nodeID, n.serviceable, n.address)
|
||||
func (n NodeInfo) String() string {
|
||||
return fmt.Sprintf("<NodeID: %d, serviceable: %v, address: %s>", n.NodeID, n.Serviceable, n.Address)
|
||||
}
|
||||
|
||||
var errClosed = errors.New("client is closed")
|
||||
|
||||
type shardClient struct {
|
||||
sync.RWMutex
|
||||
info nodeInfo
|
||||
info NodeInfo
|
||||
poolSize int
|
||||
clients []types.QueryNodeClient
|
||||
creator queryNodeCreatorFunc
|
||||
@ -46,7 +63,7 @@ type shardClient struct {
|
||||
expiredDuration time.Duration
|
||||
}
|
||||
|
||||
func newShardClient(info nodeInfo, creator queryNodeCreatorFunc, expiredDuration time.Duration) *shardClient {
|
||||
func newShardClient(info NodeInfo, creator queryNodeCreatorFunc, expiredDuration time.Duration) *shardClient {
|
||||
return &shardClient{
|
||||
info: info,
|
||||
creator: creator,
|
||||
@ -89,14 +106,14 @@ func (n *shardClient) initClients(ctx context.Context) error {
|
||||
|
||||
clients := make([]types.QueryNodeClient, 0, poolSize)
|
||||
for i := 0; i < poolSize; i++ {
|
||||
client, err := n.creator(ctx, n.info.address, n.info.nodeID)
|
||||
client, err := n.creator(ctx, n.info.Address, n.info.NodeID)
|
||||
if err != nil {
|
||||
// Roll back already created clients
|
||||
for _, c := range clients {
|
||||
c.Close()
|
||||
}
|
||||
log.Info("failed to create client for node", zap.Int64("nodeID", n.info.nodeID), zap.Error(err))
|
||||
return errors.Wrap(err, fmt.Sprintf("create client for node=%d failed", n.info.nodeID))
|
||||
log.Info("failed to create client for node", zap.Int64("nodeID", n.info.NodeID), zap.Error(err))
|
||||
return errors.Wrap(err, fmt.Sprintf("create client for node=%d failed", n.info.NodeID))
|
||||
}
|
||||
clients = append(clients, client)
|
||||
}
|
||||
@ -150,104 +167,3 @@ func (n *shardClient) close() {
|
||||
}
|
||||
n.clients = nil
|
||||
}
|
||||
|
||||
// roundRobinSelectClient selects a client in a round-robin manner
|
||||
type shardClientMgr interface {
|
||||
GetClient(ctx context.Context, nodeInfo nodeInfo) (types.QueryNodeClient, error)
|
||||
SetClientCreatorFunc(creator queryNodeCreatorFunc)
|
||||
|
||||
Start()
|
||||
Close()
|
||||
}
|
||||
|
||||
type shardClientMgrImpl struct {
|
||||
clients *typeutil.ConcurrentMap[UniqueID, *shardClient]
|
||||
clientCreator queryNodeCreatorFunc
|
||||
closeCh chan struct{}
|
||||
|
||||
purgeInterval time.Duration
|
||||
expiredDuration time.Duration
|
||||
}
|
||||
|
||||
const (
|
||||
defaultPurgeInterval = 600 * time.Second
|
||||
defaultExpiredDuration = 60 * time.Minute
|
||||
)
|
||||
|
||||
// SessionOpt provides a way to set params in SessionManager
|
||||
type shardClientMgrOpt func(s shardClientMgr)
|
||||
|
||||
func withShardClientCreator(creator queryNodeCreatorFunc) shardClientMgrOpt {
|
||||
return func(s shardClientMgr) { s.SetClientCreatorFunc(creator) }
|
||||
}
|
||||
|
||||
func defaultQueryNodeClientCreator(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
|
||||
return registry.GetInMemoryResolver().ResolveQueryNode(ctx, addr, nodeID)
|
||||
}
|
||||
|
||||
// NewShardClientMgr creates a new shardClientMgr
|
||||
func newShardClientMgr(options ...shardClientMgrOpt) *shardClientMgrImpl {
|
||||
s := &shardClientMgrImpl{
|
||||
clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](),
|
||||
clientCreator: defaultQueryNodeClientCreator,
|
||||
closeCh: make(chan struct{}),
|
||||
purgeInterval: defaultPurgeInterval,
|
||||
expiredDuration: defaultExpiredDuration,
|
||||
}
|
||||
for _, opt := range options {
|
||||
opt(s)
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (c *shardClientMgrImpl) SetClientCreatorFunc(creator queryNodeCreatorFunc) {
|
||||
c.clientCreator = creator
|
||||
}
|
||||
|
||||
func (c *shardClientMgrImpl) GetClient(ctx context.Context, info nodeInfo) (types.QueryNodeClient, error) {
|
||||
client, _ := c.clients.GetOrInsert(info.nodeID, newShardClient(info, c.clientCreator, c.expiredDuration))
|
||||
return client.getClient(ctx)
|
||||
}
|
||||
|
||||
// PurgeClient purges client if it is not used for a long time
|
||||
func (c *shardClientMgrImpl) PurgeClient() {
|
||||
ticker := time.NewTicker(c.purgeInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.closeCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
shardLocations := globalMetaCache.ListShardLocation()
|
||||
c.clients.Range(func(key UniqueID, value *shardClient) bool {
|
||||
if _, ok := shardLocations[key]; !ok {
|
||||
// if the client is not used for more than 1 hour, and it's not a delegator anymore, should remove it
|
||||
if value.isExpired() {
|
||||
closed := value.Close(false)
|
||||
if closed {
|
||||
c.clients.Remove(key)
|
||||
log.Info("remove idle node client", zap.Int64("nodeID", key))
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *shardClientMgrImpl) Start() {
|
||||
go c.PurgeClient()
|
||||
}
|
||||
|
||||
// Close release clients
|
||||
func (c *shardClientMgrImpl) Close() {
|
||||
close(c.closeCh)
|
||||
c.clients.Range(func(key UniqueID, value *shardClient) bool {
|
||||
value.Close(true)
|
||||
c.clients.Remove(key)
|
||||
return true
|
||||
})
|
||||
}
|
||||
602
internal/proxy/shardclient/shard_client_test.go
Normal file
@ -0,0 +1,602 @@
|
||||
// Licensed to the LF AI & Data foundation under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package shardclient
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
func TestShardClientMgr(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
nodeInfo := NodeInfo{
|
||||
NodeID: 1,
|
||||
}
|
||||
|
||||
qn := mocks.NewMockQueryNodeClient(t)
|
||||
qn.EXPECT().Close().Return(nil)
|
||||
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
|
||||
return qn, nil
|
||||
}
|
||||
|
||||
mixcoord := mocks.NewMockMixCoordClient(t)
|
||||
|
||||
mgr := NewShardClientMgr(mixcoord)
|
||||
mgr.SetClientCreatorFunc(creator)
|
||||
_, err := mgr.GetClient(ctx, nodeInfo)
|
||||
assert.Nil(t, err)
|
||||
|
||||
mgr.Close()
|
||||
assert.Equal(t, mgr.clients.Len(), 0)
|
||||
}
|
||||
|
||||
func TestShardClient(t *testing.T) {
|
||||
nodeInfo := NodeInfo{
|
||||
NodeID: 1,
|
||||
}
|
||||
|
||||
qn := mocks.NewMockQueryNodeClient(t)
|
||||
qn.EXPECT().Close().Return(nil).Maybe()
|
||||
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
|
||||
return qn, nil
|
||||
}
|
||||
shardClient := newShardClient(nodeInfo, creator, 3*time.Second)
|
||||
assert.Equal(t, len(shardClient.clients), 0)
|
||||
assert.Equal(t, false, shardClient.initialized.Load())
|
||||
assert.Equal(t, false, shardClient.isClosed)
|
||||
|
||||
ctx := context.Background()
|
||||
_, err := shardClient.getClient(ctx)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, len(shardClient.clients), paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt())
|
||||
|
||||
// test close
|
||||
closed := shardClient.Close(false)
|
||||
assert.False(t, closed)
|
||||
closed = shardClient.Close(true)
|
||||
assert.True(t, closed)
|
||||
}
|
||||
|
||||
func TestPurgeClient(t *testing.T) {
|
||||
node := NodeInfo{
|
||||
NodeID: 1,
|
||||
}
|
||||
|
||||
returnEmptyResult := atomic.NewBool(false)
|
||||
|
||||
qn := mocks.NewMockQueryNodeClient(t)
|
||||
qn.EXPECT().Close().Return(nil).Maybe()
|
||||
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
|
||||
return qn, nil
|
||||
}
|
||||
|
||||
s := &shardClientMgrImpl{
|
||||
clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](),
|
||||
clientCreator: creator,
|
||||
closeCh: make(chan struct{}),
|
||||
purgeInterval: 1 * time.Second,
|
||||
expiredDuration: 3 * time.Second,
|
||||
|
||||
collLeader: map[string]map[string]*shardLeaders{
|
||||
"default": {
|
||||
"test": {
|
||||
idx: atomic.NewInt64(0),
|
||||
collectionID: 1,
|
||||
shardLeaders: map[string][]NodeInfo{
|
||||
"0": {node},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.PurgeClient()
|
||||
defer s.Close()
|
||||
_, err := s.GetClient(context.Background(), node)
|
||||
assert.Nil(t, err)
|
||||
qnClient, ok := s.clients.Get(1)
|
||||
assert.True(t, ok)
|
||||
assert.True(t, qnClient.lastActiveTs.Load() > 0)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
// the client should not be purged before expiredDuration elapses
|
||||
assert.Equal(t, s.clients.Len(), 1)
|
||||
assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() >= 2*time.Second.Nanoseconds())
|
||||
|
||||
_, err = s.GetClient(context.Background(), node)
|
||||
assert.Nil(t, err)
|
||||
time.Sleep(2 * time.Second)
|
||||
// GetClient should refresh lastActiveTs, so the client should not be purged
|
||||
assert.Equal(t, s.clients.Len(), 1)
|
||||
assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() < 3*time.Second.Nanoseconds())
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
// the client has exceeded expiredDuration but is still a shard leader, so it should not be purged yet
|
||||
assert.Equal(t, s.clients.Len(), 1)
|
||||
assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() > 3*time.Second.Nanoseconds())
|
||||
|
||||
s.DeprecateShardCache("default", "test")
|
||||
returnEmptyResult.Store(true)
|
||||
time.Sleep(2 * time.Second)
|
||||
// once the client is removed from the shard location, it should be purged
|
||||
assert.Eventually(t, func() bool {
|
||||
return s.clients.Len() == 0
|
||||
}, 10*time.Second, 1*time.Second)
|
||||
}
|
||||
|
||||
func TestDeprecateShardCache(t *testing.T) {
|
||||
node := NodeInfo{
|
||||
NodeID: 1,
|
||||
}
|
||||
|
||||
qn := mocks.NewMockQueryNodeClient(t)
|
||||
qn.EXPECT().Close().Return(nil).Maybe()
|
||||
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
|
||||
return qn, nil
|
||||
}
|
||||
|
||||
mixcoord := mocks.NewMockMixCoordClient(t)
|
||||
mgr := NewShardClientMgr(mixcoord)
|
||||
mgr.SetClientCreatorFunc(creator)
|
||||
|
||||
t.Run("Clear with no collection info", func(t *testing.T) {
|
||||
mgr.DeprecateShardCache("default", "collection_not_exist")
|
||||
// Should not panic or error
|
||||
})
|
||||
|
||||
t.Run("Clear valid collection empty cache", func(t *testing.T) {
|
||||
// Add a collection to cache first
|
||||
mgr.leaderMut.Lock()
|
||||
mgr.collLeader = map[string]map[string]*shardLeaders{
|
||||
"default": {
|
||||
"test_collection": {
|
||||
idx: atomic.NewInt64(0),
|
||||
collectionID: 100,
|
||||
shardLeaders: map[string][]NodeInfo{
|
||||
"channel-1": {node},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
mgr.leaderMut.Unlock()
|
||||
|
||||
mgr.DeprecateShardCache("default", "test_collection")
|
||||
|
||||
// Verify cache is cleared
|
||||
mgr.leaderMut.RLock()
|
||||
defer mgr.leaderMut.RUnlock()
|
||||
_, exists := mgr.collLeader["default"]["test_collection"]
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Clear one collection, keep others", func(t *testing.T) {
|
||||
mgr.leaderMut.Lock()
|
||||
mgr.collLeader = map[string]map[string]*shardLeaders{
|
||||
"default": {
|
||||
"collection1": {
|
||||
idx: atomic.NewInt64(0),
|
||||
collectionID: 100,
|
||||
shardLeaders: map[string][]NodeInfo{
|
||||
"channel-1": {node},
|
||||
},
|
||||
},
|
||||
"collection2": {
|
||||
idx: atomic.NewInt64(0),
|
||||
collectionID: 101,
|
||||
shardLeaders: map[string][]NodeInfo{
|
||||
"channel-2": {node},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
mgr.leaderMut.Unlock()
|
||||
|
||||
mgr.DeprecateShardCache("default", "collection1")
|
||||
|
||||
// Verify collection1 is cleared but collection2 remains
|
||||
mgr.leaderMut.RLock()
|
||||
defer mgr.leaderMut.RUnlock()
|
||||
_, exists := mgr.collLeader["default"]["collection1"]
|
||||
assert.False(t, exists)
|
||||
_, exists = mgr.collLeader["default"]["collection2"]
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Clear last collection in database removes database", func(t *testing.T) {
|
||||
mgr.leaderMut.Lock()
|
||||
mgr.collLeader = map[string]map[string]*shardLeaders{
|
||||
"test_db": {
|
||||
"last_collection": {
|
||||
idx: atomic.NewInt64(0),
|
||||
collectionID: 200,
|
||||
shardLeaders: map[string][]NodeInfo{
|
||||
"channel-1": {node},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
mgr.leaderMut.Unlock()
|
||||
|
||||
mgr.DeprecateShardCache("test_db", "last_collection")
|
||||
|
||||
// Verify database is also removed
|
||||
mgr.leaderMut.RLock()
|
||||
defer mgr.leaderMut.RUnlock()
|
||||
_, exists := mgr.collLeader["test_db"]
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
mgr.Close()
|
||||
}
|
||||
|
||||
func TestInvalidateShardLeaderCache(t *testing.T) {
|
||||
node := NodeInfo{
|
||||
NodeID: 1,
|
||||
}
|
||||
|
||||
qn := mocks.NewMockQueryNodeClient(t)
|
||||
qn.EXPECT().Close().Return(nil).Maybe()
|
||||
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
|
||||
return qn, nil
|
||||
}
|
||||
|
||||
mixcoord := mocks.NewMockMixCoordClient(t)
|
||||
mgr := NewShardClientMgr(mixcoord)
|
||||
mgr.SetClientCreatorFunc(creator)
|
||||
|
||||
t.Run("Invalidate single collection", func(t *testing.T) {
|
||||
mgr.leaderMut.Lock()
|
||||
mgr.collLeader = map[string]map[string]*shardLeaders{
|
||||
"default": {
|
||||
"collection1": {
|
||||
idx: atomic.NewInt64(0),
|
||||
collectionID: 100,
|
||||
shardLeaders: map[string][]NodeInfo{
|
||||
"channel-1": {node},
|
||||
},
|
||||
},
|
||||
"collection2": {
|
||||
idx: atomic.NewInt64(0),
|
||||
collectionID: 101,
|
||||
shardLeaders: map[string][]NodeInfo{
|
||||
"channel-2": {node},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
mgr.leaderMut.Unlock()
|
||||
|
||||
mgr.InvalidateShardLeaderCache([]int64{100})
|
||||
|
||||
// Verify collection with ID 100 is removed, but 101 remains
|
||||
mgr.leaderMut.RLock()
|
||||
defer mgr.leaderMut.RUnlock()
|
||||
_, exists := mgr.collLeader["default"]["collection1"]
|
||||
assert.False(t, exists)
|
||||
_, exists = mgr.collLeader["default"]["collection2"]
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Invalidate multiple collections", func(t *testing.T) {
|
||||
mgr.leaderMut.Lock()
|
||||
mgr.collLeader = map[string]map[string]*shardLeaders{
|
||||
"default": {
|
||||
"collection1": {
|
||||
idx: atomic.NewInt64(0),
|
||||
collectionID: 100,
|
||||
shardLeaders: map[string][]NodeInfo{
|
||||
"channel-1": {node},
|
||||
},
|
||||
},
|
||||
"collection2": {
|
||||
idx: atomic.NewInt64(0),
|
||||
collectionID: 101,
|
||||
shardLeaders: map[string][]NodeInfo{
|
||||
"channel-2": {node},
|
||||
},
|
||||
},
|
||||
"collection3": {
|
||||
idx: atomic.NewInt64(0),
|
||||
collectionID: 102,
|
||||
shardLeaders: map[string][]NodeInfo{
|
||||
"channel-3": {node},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
mgr.leaderMut.Unlock()
|
||||
|
||||
mgr.InvalidateShardLeaderCache([]int64{100, 102})
|
||||
|
||||
// Verify collections 100 and 102 are removed, but 101 remains
|
||||
mgr.leaderMut.RLock()
|
||||
defer mgr.leaderMut.RUnlock()
|
||||
_, exists := mgr.collLeader["default"]["collection1"]
|
||||
assert.False(t, exists)
|
||||
_, exists = mgr.collLeader["default"]["collection2"]
|
||||
assert.True(t, exists)
|
||||
_, exists = mgr.collLeader["default"]["collection3"]
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Invalidate non-existent collection", func(t *testing.T) {
|
||||
mgr.leaderMut.Lock()
|
||||
mgr.collLeader = map[string]map[string]*shardLeaders{
|
||||
"default": {
|
||||
"collection1": {
|
||||
idx: atomic.NewInt64(0),
|
||||
collectionID: 100,
|
||||
shardLeaders: map[string][]NodeInfo{
|
||||
"channel-1": {node},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
mgr.leaderMut.Unlock()
|
||||
|
||||
mgr.InvalidateShardLeaderCache([]int64{999})
|
||||
|
||||
// Verify collection1 still exists
|
||||
mgr.leaderMut.RLock()
|
||||
defer mgr.leaderMut.RUnlock()
|
||||
_, exists := mgr.collLeader["default"]["collection1"]
|
||||
assert.True(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Invalidate all collections in database removes database", func(t *testing.T) {
|
||||
mgr.leaderMut.Lock()
|
||||
mgr.collLeader = map[string]map[string]*shardLeaders{
|
||||
"test_db": {
|
||||
"collection1": {
|
||||
idx: atomic.NewInt64(0),
|
||||
collectionID: 200,
|
||||
shardLeaders: map[string][]NodeInfo{
|
||||
"channel-1": {node},
|
||||
},
|
||||
},
|
||||
"collection2": {
|
||||
idx: atomic.NewInt64(0),
|
||||
collectionID: 201,
|
||||
shardLeaders: map[string][]NodeInfo{
|
||||
"channel-2": {node},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
mgr.leaderMut.Unlock()
|
||||
|
||||
mgr.InvalidateShardLeaderCache([]int64{200, 201})
|
||||
|
||||
// Verify database is removed
|
||||
mgr.leaderMut.RLock()
|
||||
defer mgr.leaderMut.RUnlock()
|
||||
_, exists := mgr.collLeader["test_db"]
|
||||
assert.False(t, exists)
|
||||
})
|
||||
|
||||
t.Run("Invalidate across multiple databases", func(t *testing.T) {
|
||||
mgr.leaderMut.Lock()
|
||||
mgr.collLeader = map[string]map[string]*shardLeaders{
|
||||
"db1": {
|
||||
"collection1": {
|
||||
idx: atomic.NewInt64(0),
|
||||
collectionID: 100,
|
||||
shardLeaders: map[string][]NodeInfo{
|
||||
"channel-1": {node},
|
||||
},
|
||||
},
|
||||
},
|
||||
"db2": {
|
||||
"collection2": {
|
||||
idx: atomic.NewInt64(0),
|
||||
collectionID: 100, // Same collection ID in different database
|
||||
shardLeaders: map[string][]NodeInfo{
|
||||
"channel-2": {node},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
mgr.leaderMut.Unlock()
|
||||
|
||||
mgr.InvalidateShardLeaderCache([]int64{100})
|
||||
|
||||
// Verify collection is removed from both databases
|
||||
mgr.leaderMut.RLock()
|
||||
defer mgr.leaderMut.RUnlock()
|
||||
_, exists := mgr.collLeader["db1"]
|
||||
assert.False(t, exists) // db1 should be removed
|
||||
_, exists = mgr.collLeader["db2"]
|
||||
assert.False(t, exists) // db2 should be removed
|
||||
})
|
||||
|
||||
mgr.Close()
|
||||
}
|
||||
|
||||
func TestShuffleShardLeaders(t *testing.T) {
|
||||
t.Run("Shuffle with multiple nodes", func(t *testing.T) {
|
||||
shards := map[string][]NodeInfo{
|
||||
"channel-1": {
|
||||
{NodeID: 1, Address: "localhost:9000", Serviceable: true},
|
||||
{NodeID: 2, Address: "localhost:9001", Serviceable: true},
|
||||
{NodeID: 3, Address: "localhost:9002", Serviceable: true},
|
||||
},
|
||||
}
|
||||
sl := &shardLeaders{
|
||||
idx: atomic.NewInt64(5),
|
||||
collectionID: 100,
|
||||
shardLeaders: shards,
|
||||
}
|
||||
|
||||
reader := sl.GetReader()
|
||||
result := reader.Shuffle()
|
||||
|
||||
// Verify result has same channel
|
||||
assert.Len(t, result, 1)
|
||||
assert.Contains(t, result, "channel-1")
|
||||
|
||||
// Verify all nodes are present
|
||||
assert.Len(t, result["channel-1"], 3)
|
||||
|
||||
// Verify the first node is based on idx rotation (idx=6, 6%3=0, so nodeID 1 should be first)
|
||||
assert.Equal(t, int64(1), result["channel-1"][0].NodeID)
|
||||
|
||||
// Verify all nodes are still present (shuffled)
|
||||
nodeIDs := make(map[int64]bool)
|
||||
for _, node := range result["channel-1"] {
|
||||
nodeIDs[node.NodeID] = true
|
||||
}
|
||||
assert.True(t, nodeIDs[1])
|
||||
assert.True(t, nodeIDs[2])
|
||||
assert.True(t, nodeIDs[3])
|
||||
})
|
||||
|
||||
t.Run("Shuffle rotates first replica based on idx", func(t *testing.T) {
|
||||
shards := map[string][]NodeInfo{
|
||||
"channel-1": {
|
||||
{NodeID: 1, Address: "localhost:9000", Serviceable: true},
|
||||
{NodeID: 2, Address: "localhost:9001", Serviceable: true},
|
||||
{NodeID: 3, Address: "localhost:9002", Serviceable: true},
|
||||
},
|
||||
}
|
||||
sl := &shardLeaders{
|
||||
idx: atomic.NewInt64(5),
|
||||
collectionID: 100,
|
||||
shardLeaders: shards,
|
||||
}
|
||||
|
||||
// First read, idx will be 6 (5+1), 6%3=0, so first replica should be leaders[0] which is nodeID 1
|
||||
reader := sl.GetReader()
|
||||
result := reader.Shuffle()
|
||||
assert.Equal(t, int64(1), result["channel-1"][0].NodeID)
|
||||
|
||||
// Second read, idx will be 7 (6+1), 7%3=1, so first replica should be leaders[1] which is nodeID 2
|
||||
reader = sl.GetReader()
|
||||
result = reader.Shuffle()
|
||||
assert.Equal(t, int64(2), result["channel-1"][0].NodeID)
|
||||
|
||||
// Third read, idx will be 8 (7+1), 8%3=2, so first replica should be leaders[2] which is nodeID 3
|
||||
reader = sl.GetReader()
|
||||
result = reader.Shuffle()
|
||||
assert.Equal(t, int64(3), result["channel-1"][0].NodeID)
|
||||
})
|
||||
|
||||
t.Run("Shuffle with single node", func(t *testing.T) {
|
||||
shards := map[string][]NodeInfo{
|
||||
"channel-1": {
|
||||
{NodeID: 1, Address: "localhost:9000", Serviceable: true},
|
||||
},
|
||||
}
|
||||
sl := &shardLeaders{
|
||||
idx: atomic.NewInt64(0),
|
||||
collectionID: 100,
|
||||
shardLeaders: shards,
|
||||
}
|
||||
|
||||
reader := sl.GetReader()
|
||||
result := reader.Shuffle()
|
||||
|
||||
assert.Len(t, result["channel-1"], 1)
|
||||
assert.Equal(t, int64(1), result["channel-1"][0].NodeID)
|
||||
})
|
||||
|
||||
t.Run("Shuffle with multiple channels", func(t *testing.T) {
|
||||
shards := map[string][]NodeInfo{
|
||||
"channel-1": {
|
||||
{NodeID: 1, Address: "localhost:9000", Serviceable: true},
|
||||
{NodeID: 2, Address: "localhost:9001", Serviceable: true},
|
||||
},
|
||||
"channel-2": {
|
||||
{NodeID: 3, Address: "localhost:9002", Serviceable: true},
|
||||
{NodeID: 4, Address: "localhost:9003", Serviceable: true},
|
||||
},
|
||||
}
|
||||
sl := &shardLeaders{
|
||||
idx: atomic.NewInt64(0),
|
||||
collectionID: 100,
|
||||
shardLeaders: shards,
|
||||
}
|
||||
|
||||
reader := sl.GetReader()
|
||||
result := reader.Shuffle()
|
||||
|
||||
// Verify both channels are present
|
||||
assert.Len(t, result, 2)
|
||||
assert.Contains(t, result, "channel-1")
|
||||
assert.Contains(t, result, "channel-2")
|
||||
|
||||
// Verify each channel has correct number of nodes
|
||||
assert.Len(t, result["channel-1"], 2)
|
||||
assert.Len(t, result["channel-2"], 2)
|
||||
})
|
||||
|
||||
t.Run("Shuffle with empty leaders", func(t *testing.T) {
|
||||
shards := map[string][]NodeInfo{}
|
||||
sl := &shardLeaders{
|
||||
idx: atomic.NewInt64(0),
|
||||
collectionID: 100,
|
||||
shardLeaders: shards,
|
||||
}
|
||||
|
||||
reader := sl.GetReader()
|
||||
result := reader.Shuffle()
|
||||
|
||||
assert.Len(t, result, 0)
|
||||
})
|
||||
}
|
||||
|
||||
// func BenchmarkShardClientMgr(b *testing.B) {
|
||||
// node := nodeInfo{
|
||||
// nodeID: 1,
|
||||
// }
|
||||
// cache := NewMockCache(b)
|
||||
// cache.EXPECT().ListShardLocation().Return(map[int64]nodeInfo{
|
||||
// 1: node,
|
||||
// }).Maybe()
|
||||
// globalMetaCache = cache
|
||||
// qn := mocks.NewMockQueryNodeClient(b)
|
||||
// qn.EXPECT().Close().Return(nil).Maybe()
|
||||
|
||||
// creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
|
||||
// return qn, nil
|
||||
// }
|
||||
// s := &shardClientMgrImpl{
|
||||
// clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](),
|
||||
// clientCreator: creator,
|
||||
// closeCh: make(chan struct{}),
|
||||
// purgeInterval: 1 * time.Second,
|
||||
// expiredDuration: 10 * time.Second,
|
||||
// }
|
||||
// go s.PurgeClient()
|
||||
// defer s.Close()
|
||||
|
||||
// b.ResetTimer()
|
||||
// b.RunParallel(func(pb *testing.PB) {
|
||||
// for pb.Next() {
|
||||
// _, err := s.GetClient(context.Background(), node)
|
||||
// assert.Nil(b, err)
|
||||
// }
|
||||
// })
|
||||
// }
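If this benchmark is revived, one possible sketch mirrors the TestPurgeClient setup above instead of the old globalMetaCache dependency; the collection/channel values and durations here are assumptions for illustration, not a committed design:

func BenchmarkShardClientMgr(b *testing.B) {
	node := NodeInfo{NodeID: 1}

	qn := mocks.NewMockQueryNodeClient(b)
	qn.EXPECT().Close().Return(nil).Maybe()
	creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
		return qn, nil
	}

	s := &shardClientMgrImpl{
		clients:         typeutil.NewConcurrentMap[UniqueID, *shardClient](),
		clientCreator:   creator,
		closeCh:         make(chan struct{}),
		purgeInterval:   1 * time.Second,
		expiredDuration: 10 * time.Second,
		// keep node 1 registered as a shard leader so PurgeClient does not evict it mid-run
		collLeader: map[string]map[string]*shardLeaders{
			"default": {
				"bench": {
					idx:          atomic.NewInt64(0),
					collectionID: 1,
					shardLeaders: map[string][]NodeInfo{"0": {node}},
				},
			},
		},
	}
	go s.PurgeClient()
	defer s.Close()

	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			if _, err := s.GetClient(context.Background(), node); err != nil {
				b.Error(err)
			}
		}
	})
}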
|
||||
@ -30,6 +30,7 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/proxy/shardclient"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/log"
|
||||
@ -3030,7 +3031,7 @@ type RunAnalyzerTask struct {
|
||||
collectionID typeutil.UniqueID
|
||||
fieldID typeutil.UniqueID
|
||||
dbName string
|
||||
lb LBPolicy
|
||||
lb shardclient.LBPolicy
|
||||
|
||||
result *milvuspb.RunAnalyzerResponse
|
||||
}
|
||||
@ -3122,12 +3123,12 @@ func (t *RunAnalyzerTask) runAnalyzerOnShardleader(ctx context.Context, nodeID i
|
||||
}
|
||||
|
||||
func (t *RunAnalyzerTask) Execute(ctx context.Context) error {
|
||||
err := t.lb.ExecuteOneChannel(ctx, CollectionWorkLoad{
|
||||
db: t.dbName,
|
||||
collectionName: t.GetCollectionName(),
|
||||
collectionID: t.collectionID,
|
||||
nq: int64(len(t.GetPlaceholder())),
|
||||
exec: t.runAnalyzerOnShardleader,
|
||||
err := t.lb.ExecuteOneChannel(ctx, shardclient.CollectionWorkLoad{
|
||||
Db: t.dbName,
|
||||
CollectionName: t.GetCollectionName(),
|
||||
CollectionID: t.collectionID,
|
||||
Nq: int64(len(t.GetPlaceholder())),
|
||||
Exec: t.runAnalyzerOnShardleader,
|
||||
})
|
||||
|
||||
return err
|
||||
|
||||
@ -17,6 +17,7 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/allocator"
|
||||
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||
"github.com/milvus-io/milvus/internal/proxy/shardclient"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
"github.com/milvus-io/milvus/internal/util/exprutil"
|
||||
"github.com/milvus-io/milvus/internal/util/segcore"
|
||||
@ -249,7 +250,7 @@ type deleteRunner struct {
|
||||
// for query
|
||||
msgID int64
|
||||
ts uint64
|
||||
lb LBPolicy
|
||||
lb shardclient.LBPolicy
|
||||
count atomic.Int64
|
||||
|
||||
// task queue
|
||||
@ -409,7 +410,7 @@ func (dr *deleteRunner) produce(ctx context.Context, primaryKeys *schemapb.IDs,
|
||||
|
||||
// getStreamingQueryAndDelteFunc return query function used by LBPolicy
|
||||
// make sure it concurrent safe
|
||||
func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) executeFunc {
|
||||
func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) shardclient.ExecuteFunc {
|
||||
return func(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channel string) error {
|
||||
log := log.Ctx(ctx).With(
|
||||
zap.Int64("collectionID", dr.collectionID),
|
||||
@ -556,12 +557,12 @@ func (dr *deleteRunner) complexDelete(ctx context.Context, plan *planpb.PlanNode
|
||||
return err
|
||||
}
|
||||
|
||||
err = dr.lb.Execute(ctx, CollectionWorkLoad{
|
||||
db: dr.req.GetDbName(),
|
||||
collectionName: dr.req.GetCollectionName(),
|
||||
collectionID: dr.collectionID,
|
||||
nq: 1,
|
||||
exec: dr.getStreamingQueryAndDelteFunc(plan),
|
||||
err = dr.lb.Execute(ctx, shardclient.CollectionWorkLoad{
|
||||
Db: dr.req.GetDbName(),
|
||||
CollectionName: dr.req.GetCollectionName(),
|
||||
CollectionID: dr.collectionID,
|
||||
Nq: 1,
|
||||
Exec: dr.getStreamingQueryAndDelteFunc(plan),
|
||||
})
|
||||
dr.result.DeleteCnt = dr.count.Load()
|
||||
dr.result.Timestamp = dr.sessionTS.Load()
|
||||
|
||||
@ -17,6 +17,7 @@ import (
|
||||
"github.com/milvus-io/milvus/internal/distributed/streaming"
|
||||
"github.com/milvus-io/milvus/internal/mocks"
|
||||
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
||||
"github.com/milvus-io/milvus/internal/proxy/shardclient"
|
||||
"github.com/milvus-io/milvus/internal/util/streamrpc"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
||||
@ -683,7 +684,7 @@ func TestDeleteRunner_Run(t *testing.T) {
|
||||
|
||||
t.Run("simple delete task failed", func(t *testing.T) {
|
||||
mockMgr := NewMockChannelsMgr(t)
|
||||
lb := NewMockLBPolicy(t)
|
||||
lb := shardclient.NewMockLBPolicy(t)
|
||||
|
||||
expr := "pk in [1,2,3]"
|
||||
plan, err := planparserv2.CreateRetrievePlan(schema.schemaHelper, expr, nil)
|
||||
@ -722,7 +723,7 @@ func TestDeleteRunner_Run(t *testing.T) {
|
||||
t.Run("complex delete query rpc failed", func(t *testing.T) {
|
||||
mockMgr := NewMockChannelsMgr(t)
|
||||
qn := mocks.NewMockQueryNodeClient(t)
|
||||
lb := NewMockLBPolicy(t)
|
||||
lb := shardclient.NewMockLBPolicy(t)
|
||||
expr := "pk < 3"
|
||||
plan, err := planparserv2.CreateRetrievePlan(schema.schemaHelper, expr, nil)
|
||||
require.NoError(t, err)
|
||||
@ -749,8 +750,8 @@ func TestDeleteRunner_Run(t *testing.T) {
|
||||
},
|
||||
plan: plan,
|
||||
}
|
||||
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
|
||||
return workload.exec(ctx, 1, qn, "")
|
||||
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload shardclient.CollectionWorkLoad) error {
|
||||
return workload.Exec(ctx, 1, qn, "")
|
||||
})
|
||||
|
||||
qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
|
||||
@ -764,7 +765,7 @@ func TestDeleteRunner_Run(t *testing.T) {
|
||||
|
||||
mockMgr := NewMockChannelsMgr(t)
|
||||
qn := mocks.NewMockQueryNodeClient(t)
|
||||
lb := NewMockLBPolicy(t)
|
||||
lb := shardclient.NewMockLBPolicy(t)
|
||||
expr := "pk < 3"
|
||||
plan, err := planparserv2.CreateRetrievePlan(schema.schemaHelper, expr, nil)
|
||||
require.NoError(t, err)
|
||||
@ -795,8 +796,8 @@ func TestDeleteRunner_Run(t *testing.T) {
|
||||
}
|
||||
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
|
||||
|
||||
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
|
||||
return workload.exec(ctx, 1, qn, "")
|
||||
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload shardclient.CollectionWorkLoad) error {
|
||||
return workload.Exec(ctx, 1, qn, "")
|
||||
})
|
||||
|
||||
qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Call.Return(
|
||||
@ -830,7 +831,7 @@ func TestDeleteRunner_Run(t *testing.T) {
|
||||
|
||||
mockMgr := NewMockChannelsMgr(t)
|
||||
qn := mocks.NewMockQueryNodeClient(t)
|
||||
lb := NewMockLBPolicy(t)
|
||||
lb := shardclient.NewMockLBPolicy(t)
|
||||
expr := "pk < 3"
|
plan, err := planparserv2.CreateRetrievePlan(schema.schemaHelper, expr, nil)
require.NoError(t, err)
@ -860,8 +861,8 @@ func TestDeleteRunner_Run(t *testing.T) {
},
plan: plan,
}
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "")
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload shardclient.CollectionWorkLoad) error {
return workload.Exec(ctx, 1, qn, "")
})

qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Call.Return(
@ -893,7 +894,7 @@ func TestDeleteRunner_Run(t *testing.T) {

mockMgr := NewMockChannelsMgr(t)
qn := mocks.NewMockQueryNodeClient(t)
lb := NewMockLBPolicy(t)
lb := shardclient.NewMockLBPolicy(t)
expr := "pk < 3"
plan, err := planparserv2.CreateRetrievePlan(schema.schemaHelper, expr, nil)
require.NoError(t, err)
@ -923,8 +924,8 @@ func TestDeleteRunner_Run(t *testing.T) {
plan: plan,
}
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "")
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload shardclient.CollectionWorkLoad) error {
return workload.Exec(ctx, 1, qn, "")
})

qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Call.Return(
@ -957,7 +958,7 @@ func TestDeleteRunner_Run(t *testing.T) {

mockMgr := NewMockChannelsMgr(t)
qn := mocks.NewMockQueryNodeClient(t)
lb := NewMockLBPolicy(t)
lb := shardclient.NewMockLBPolicy(t)
expr := "pk < 3"
plan, err := planparserv2.CreateRetrievePlan(schema.schemaHelper, expr, nil)
require.NoError(t, err)
@ -987,8 +988,8 @@ func TestDeleteRunner_Run(t *testing.T) {
plan: plan,
}
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "")
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload shardclient.CollectionWorkLoad) error {
return workload.Exec(ctx, 1, qn, "")
})

qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Call.Return(
@ -1025,7 +1026,7 @@ func TestDeleteRunner_Run(t *testing.T) {

mockMgr := NewMockChannelsMgr(t)
qn := mocks.NewMockQueryNodeClient(t)
lb := NewMockLBPolicy(t)
lb := shardclient.NewMockLBPolicy(t)

mockCache := NewMockCache(t)
mockCache.EXPECT().GetCollectionID(mock.Anything, dbName, collectionName).Return(collectionID, nil).Maybe()
@ -1059,8 +1060,8 @@ func TestDeleteRunner_Run(t *testing.T) {
plan: plan,
}
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "")
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload shardclient.CollectionWorkLoad) error {
return workload.Exec(ctx, 1, qn, "")
})

qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Call.Return(

@ -85,9 +85,8 @@ func TestGetIndexStateTask_Execute(t *testing.T) {
collectionID: collectionID,
}

shardMgr := newShardClientMgr()
// failed to get collection id.
err := InitMetaCache(ctx, queryCoord, shardMgr)
err := InitMetaCache(ctx, queryCoord)
assert.NoError(t, err)
assert.Error(t, gist.Execute(ctx))
}

@ -1,5 +1,6 @@
package proxy

/*
import (
"context"

@ -7,6 +8,7 @@ import (
"go.uber.org/zap"
"golang.org/x/sync/errgroup"

"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
@ -16,7 +18,7 @@ import (

type queryFunc func(context.Context, UniqueID, types.QueryNodeClient, ...string) error

type pickShardPolicy func(context.Context, shardClientMgr, queryFunc, map[string][]nodeInfo) error
type pickShardPolicy func(context.Context, shardclient.ShardClientMgr, queryFunc, map[string][]nodeInfo) error

var errInvalidShardLeaders = errors.New("Invalid shard leader")

@ -24,7 +26,7 @@ var errInvalidShardLeaders = errors.New("Invalid shard leader")
// if request failed, it finds shard leader for failed dml channels
func RoundRobinPolicy(
ctx context.Context,
mgr shardClientMgr,
mgr shardclient.ShardClientMgr,
query queryFunc,
dml2leaders map[string][]nodeInfo,
) error {
@ -65,3 +67,4 @@ func RoundRobinPolicy(
err := wg.Wait()
return err
}
*/

@ -1,5 +1,6 @@
package proxy

/*
import (
"context"
"sort"
@ -83,3 +84,4 @@ func (m *mockQuery) records() map[UniqueID][]string {
}
return m.queryset
}
*/

@ -17,6 +17,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/proxy/accesslog"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/exprutil"
"github.com/milvus-io/milvus/internal/util/reduce"
@ -69,7 +70,8 @@ type queryTask struct {

plan *planpb.PlanNode
partitionKeyMode bool
lb LBPolicy
shardclientMgr shardclient.ShardClientMgr
lb shardclient.LBPolicy
channelsMvcc map[string]Timestamp
fastSkip bool

@ -575,12 +577,12 @@ func (t *queryTask) Execute(ctx context.Context) error {
zap.String("requestType", "query"))

t.resultBuf = typeutil.NewConcurrentSet[*internalpb.RetrieveResults]()
err := t.lb.Execute(ctx, CollectionWorkLoad{
db: t.request.GetDbName(),
collectionID: t.CollectionID,
collectionName: t.collectionName,
nq: 1,
exec: t.queryShard,
err := t.lb.Execute(ctx, shardclient.CollectionWorkLoad{
Db: t.request.GetDbName(),
CollectionID: t.CollectionID,
CollectionName: t.collectionName,
Nq: 1,
Exec: t.queryShard,
})
if err != nil {
log.Warn("fail to execute query", zap.Error(err))
@ -730,13 +732,13 @@ func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.Query
result, err := qn.Query(ctx, req)
if err != nil {
log.Warn("QueryNode query return error", zap.Error(err))
globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
t.shardclientMgr.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
return err
}
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
log.Warn("QueryNode is not shardLeader")
globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
return errInvalidShardLeaders
t.shardclientMgr.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
return merr.Error(result.GetStatus())
}
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("QueryNode query result error", zap.Any("errorCode", result.GetStatus().GetErrorCode()), zap.String("reason", result.GetStatus().GetReason()))

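For context, a minimal sketch (not itself part of the patch) of the call pattern these hunks converge on: the per-shard workload now lives in the shardclient package with exported fields, so a proxy task builds it directly and hands its shard callback to the balancer. The helper name is hypothetical and the exact Exec signature is assumed from the test calls above (ctx, nodeID, QueryNode client, channel).

package proxy

import (
	"context"

	"github.com/milvus-io/milvus/internal/proxy/shardclient"
	"github.com/milvus-io/milvus/internal/types"
)

// sketchExecute is a hypothetical helper illustrating the exported workload fields
// shown in the hunks above; it is a sketch, not the actual task implementation.
func sketchExecute(ctx context.Context, lb shardclient.LBPolicy, dbName, collName string, collID int64) error {
	return lb.Execute(ctx, shardclient.CollectionWorkLoad{
		Db:             dbName,
		CollectionID:   collID,
		CollectionName: collName,
		Nq:             1,
		Exec: func(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channel string) error {
			// the per-shard RPC (e.g. qn.Query) goes here; errors bubble back to the balancer
			return nil
		},
	})
}
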
@ -31,6 +31,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
@ -60,11 +61,16 @@ func TestQueryTask_all(t *testing.T) {

qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything, mock.Anything).Return(nil, nil).Maybe()

mgr := NewMockShardClientManager(t)
mgr := shardclient.NewMockShardClientManager(t)
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
lb := NewLBPolicyImpl(mgr)
mgr.EXPECT().GetShardLeaderList(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]string{"mock_qn"}, nil).Maybe()
mgr.EXPECT().GetShard(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]shardclient.NodeInfo{
{NodeID: 1, Address: "mock_qn", Serviceable: true},
}, nil).Maybe()
mgr.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
lb := shardclient.NewLBPolicyImpl(mgr)

err = InitMetaCache(ctx, qc, mgr)
err = InitMetaCache(ctx, qc)
assert.NoError(t, err)

fieldName2Types := map[string]schemapb.DataType{
@ -141,8 +147,9 @@ func TestQueryTask_all(t *testing.T) {
},
},
},
mixCoord: qc,
lb: lb,
mixCoord: qc,
lb: lb,
shardclientMgr: mgr,
}

assert.NoError(t, task.OnEnqueue())
@ -290,9 +297,10 @@ func TestQueryTask_all(t *testing.T) {
},
},
},
mixCoord: qc,
lb: lb,
resultBuf: &typeutil.ConcurrentSet[*internalpb.RetrieveResults]{},
mixCoord: qc,
lb: lb,
shardclientMgr: mgr,
resultBuf: &typeutil.ConcurrentSet[*internalpb.RetrieveResults]{},
}
// simulate scheduler enqueue task
enqueTs := uint64(10000)
@ -341,9 +349,10 @@ func TestQueryTask_all(t *testing.T) {
},
GuaranteeTimestamp: enqueTs,
},
mixCoord: qc,
lb: lb,
resultBuf: &typeutil.ConcurrentSet[*internalpb.RetrieveResults]{},
mixCoord: qc,
lb: lb,
shardclientMgr: mgr,
resultBuf: &typeutil.ConcurrentSet[*internalpb.RetrieveResults]{},
}
qtErr = qt.PreExecute(context.TODO())
assert.Nil(t, qtErr)

@ -19,6 +19,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/proxy/accesslog"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/exprutil"
"github.com/milvus-io/milvus/internal/util/function/embedding"
@ -84,7 +85,8 @@ type searchTask struct {

mixCoord types.MixCoordClient
node types.ProxyComponent
lb LBPolicy
lb shardclient.LBPolicy
shardClientMgr shardclient.ShardClientMgr
queryChannelsTs map[string]Timestamp
queryInfos []*planpb.QueryInfo
relatedDataSize int64
@ -718,12 +720,12 @@ func (t *searchTask) Execute(ctx context.Context) error {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute search %d", t.ID()))
defer tr.CtxElapse(ctx, "done")

err := t.lb.Execute(ctx, CollectionWorkLoad{
db: t.request.GetDbName(),
collectionID: t.SearchRequest.CollectionID,
collectionName: t.collectionName,
nq: t.Nq,
exec: t.searchShard,
err := t.lb.Execute(ctx, shardclient.CollectionWorkLoad{
Db: t.request.GetDbName(),
CollectionID: t.SearchRequest.CollectionID,
CollectionName: t.collectionName,
Nq: t.Nq,
Exec: t.searchShard,
})
if err != nil {
log.Warn("search execute failed", zap.Error(err))
@ -928,13 +930,14 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que
result, err = qn.Search(ctx, req)
if err != nil {
log.Warn("QueryNode search return error", zap.Error(err))
globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
// globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
t.shardClientMgr.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
return err
}
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
log.Warn("QueryNode is not shardLeader")
globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
return errInvalidShardLeaders
t.shardClientMgr.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
return merr.Error(result.GetStatus())
}
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("QueryNode search result error",

@ -39,6 +39,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/function/embedding"
@ -67,9 +68,8 @@ func TestSearchTask_PostExecute(t *testing.T) {
)

require.NoError(t, err)
mgr := newShardClientMgr()

err = InitMetaCache(ctx, qc, mgr)
err = InitMetaCache(ctx, qc)
require.NoError(t, err)

getSearchTask := func(t *testing.T, collName string) *searchTask {
@ -638,8 +638,7 @@ func TestSearchTask_PreExecute(t *testing.T) {
ctx = context.TODO()
)
require.NoError(t, err)
mgr := newShardClientMgr()
err = InitMetaCache(ctx, qc, mgr)
err = InitMetaCache(ctx, qc)
require.NoError(t, err)

getSearchTask := func(t *testing.T, collName string) *searchTask {
@ -1039,8 +1038,7 @@ func TestSearchTask_WithFunctions(t *testing.T) {
)

require.NoError(t, err)
mgr := newShardClientMgr()
err = InitMetaCache(ctx, qc, mgr)
err = InitMetaCache(ctx, qc)
require.NoError(t, err)

getSearchTask := func(t *testing.T, collName string, data []string, withRerank bool) *searchTask {
@ -1105,8 +1103,6 @@ func TestSearchTask_WithFunctions(t *testing.T) {
cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(info, nil).Maybe()
cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe()
cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Maybe()
cache.EXPECT().GetShard(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]nodeInfo{}, nil).Maybe()
cache.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
globalMetaCache = cache

{
@ -1266,8 +1262,7 @@ func TestSearchTaskV2_Execute(t *testing.T) {
collectionName = t.Name() + funcutil.GenRandomStr()
)

mgr := newShardClientMgr()
err = InitMetaCache(ctx, qc, mgr)
err = InitMetaCache(ctx, qc)
require.NoError(t, err)

defer qc.Close()
@ -2925,13 +2920,18 @@ func TestSearchTask_ErrExecute(t *testing.T) {
collectionName = t.Name() + funcutil.GenRandomStr()
)

mgr := NewMockShardClientManager(t)
mgr := shardclient.NewMockShardClientManager(t)
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
lb := NewLBPolicyImpl(mgr)
mgr.EXPECT().GetShardLeaderList(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]string{"mock_qn"}, nil).Maybe()
mgr.EXPECT().GetShard(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]shardclient.NodeInfo{
{NodeID: 1, Address: "mock_qn", Serviceable: true},
}, nil).Maybe()
mgr.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
lb := shardclient.NewLBPolicyImpl(mgr)

defer rc.Close()

err = InitMetaCache(ctx, rc, mgr)
err = InitMetaCache(ctx, rc)
assert.NoError(t, err)

fieldName2Types := map[string]schemapb.DataType{
@ -3027,8 +3027,9 @@ func TestSearchTask_ErrExecute(t *testing.T) {
Nq: 2,
DslType: commonpb.DslType_BoolExprV1,
},
mixCoord: rc,
lb: lb,
mixCoord: rc,
lb: lb,
shardClientMgr: mgr,
}
for i := 0; i < len(fieldName2Types); i++ {
task.SearchRequest.OutputFieldsId[i] = int64(common.StartOfUserFieldID + i)
@ -4066,11 +4067,18 @@ func TestSearchTask_Requery(t *testing.T) {
cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(schema, nil).Maybe()
cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe()
cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Maybe()
cache.EXPECT().GetShard(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]nodeInfo{}, nil).Maybe()
cache.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
cache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{}, nil).Maybe()
globalMetaCache = cache

mgr := shardclient.NewMockShardClientManager(t)
// mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
mgr.EXPECT().GetShardLeaderList(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]string{"mock_qn"}, nil).Maybe()
mgr.EXPECT().GetShard(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]shardclient.NodeInfo{
{NodeID: 1, Address: "mock_qn", Serviceable: true},
}, nil).Maybe()
mgr.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
node.shardMgr = mgr

t.Run("Test normal", func(t *testing.T) {
collSchema := constructCollectionSchema(pkField, vecField, dim, collection)
schema := newSchemaInfo(collSchema)
@ -4115,9 +4123,9 @@ func TestSearchTask_Requery(t *testing.T) {
}, nil
})

lb := NewMockLBPolicy(t)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) {
err = workload.exec(ctx, 0, qn, "")
lb := shardclient.NewMockLBPolicy(t)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload shardclient.CollectionWorkLoad) {
err = workload.Exec(ctx, 0, qn, "")
assert.NoError(t, err)
}).Return(nil)
lb.EXPECT().UpdateCostMetrics(mock.Anything, mock.Anything).Return()
@ -4153,6 +4161,7 @@ func TestSearchTask_Requery(t *testing.T) {
tr: timerecord.NewTimeRecorder("search"),
node: node,
translatedOutputFields: outputFields,
shardClientMgr: mgr,
}
op, err := newRequeryOperator(qt, nil)
assert.NoError(t, err)
@ -4199,9 +4208,9 @@ func TestSearchTask_Requery(t *testing.T) {
qn.EXPECT().Query(mock.Anything, mock.Anything).
Return(nil, errors.New("mock err 1"))

lb := NewMockLBPolicy(t)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) {
_ = workload.exec(ctx, 0, qn, "")
lb := shardclient.NewMockLBPolicy(t)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload shardclient.CollectionWorkLoad) {
_ = workload.Exec(ctx, 0, qn, "")
}).Return(errors.New("mock err 1"))
node.lbPolicy = lb

@ -4216,9 +4225,10 @@ func TestSearchTask_Requery(t *testing.T) {
request: &milvuspb.SearchRequest{
CollectionName: collectionName,
},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
shardClientMgr: mgr,
}

op, err := newRequeryOperator(qt, nil)
@ -4235,9 +4245,9 @@ func TestSearchTask_Requery(t *testing.T) {
qn.EXPECT().Query(mock.Anything, mock.Anything).
Return(nil, errors.New("mock err 1"))

lb := NewMockLBPolicy(t)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) {
_ = workload.exec(ctx, 0, qn, "")
lb := shardclient.NewMockLBPolicy(t)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload shardclient.CollectionWorkLoad) {
_ = workload.Exec(ctx, 0, qn, "")
}).Return(errors.New("mock err 1"))
node.lbPolicy = lb

@ -4265,11 +4275,12 @@ func TestSearchTask_Requery(t *testing.T) {
Ids: resultIDs,
},
},
needRequery: true,
schema: schema,
resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](),
tr: timerecord.NewTimeRecorder("search"),
node: node,
needRequery: true,
schema: schema,
resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](),
tr: timerecord.NewTimeRecorder("search"),
node: node,
shardClientMgr: mgr,
}
scores := make([]float32, rows)
for i := range scores {

@ -12,6 +12,7 @@ import (

"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
@ -54,7 +55,8 @@ type getStatisticsTask struct {
*internalpb.GetStatisticsRequest
resultBuf *typeutil.ConcurrentSet[*internalpb.GetStatisticsResponse]

lb LBPolicy
shardclientMgr shardclient.ShardClientMgr
lb shardclient.LBPolicy
}

func (g *getStatisticsTask) TraceCtx() context.Context {
@ -258,12 +260,12 @@ func (g *getStatisticsTask) getStatisticsFromQueryNode(ctx context.Context) erro
if g.resultBuf == nil {
g.resultBuf = typeutil.NewConcurrentSet[*internalpb.GetStatisticsResponse]()
}
err := g.lb.Execute(ctx, CollectionWorkLoad{
db: g.request.GetDbName(),
collectionID: g.GetStatisticsRequest.CollectionID,
collectionName: g.collectionName,
nq: 1,
exec: g.getStatisticsShard,
err := g.lb.Execute(ctx, shardclient.CollectionWorkLoad{
Db: g.request.GetDbName(),
CollectionID: g.GetStatisticsRequest.CollectionID,
CollectionName: g.collectionName,
Nq: 1,
Exec: g.getStatisticsShard,
})
if err != nil {
return errors.Wrap(err, "failed to statistic")
@ -286,15 +288,15 @@ func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64
zap.Int64("nodeID", nodeID),
zap.String("channel", channel),
zap.Error(err))
globalMetaCache.DeprecateShardCache(g.request.GetDbName(), g.collectionName)
g.shardclientMgr.DeprecateShardCache(g.request.GetDbName(), g.collectionName)
return err
}
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
log.Ctx(ctx).Warn("QueryNode is not shardLeader",
zap.Int64("nodeID", nodeID),
zap.String("channel", channel))
globalMetaCache.DeprecateShardCache(g.request.GetDbName(), g.collectionName)
return errInvalidShardLeaders
g.shardclientMgr.DeprecateShardCache(g.request.GetDbName(), g.collectionName)
return merr.Error(result.GetStatus())
}
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Ctx(ctx).Warn("QueryNode statistic result error",

@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
@ -43,7 +44,8 @@ type StatisticTaskSuite struct {
mixc types.MixCoordClient
qn *mocks.MockQueryNodeClient

lb LBPolicy
lb shardclient.LBPolicy
mgr shardclient.ShardClientMgr

collectionName string
collectionID int64
@ -82,11 +84,17 @@ func (s *StatisticTaskSuite) SetupTest() {
s.qn = mocks.NewMockQueryNodeClient(s.T())

s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
mgr := NewMockShardClientManager(s.T())
mgr := shardclient.NewMockShardClientManager(s.T())
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil).Maybe()
s.lb = NewLBPolicyImpl(mgr)
mgr.EXPECT().GetShardLeaderList(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]string{"mock_qn"}, nil).Maybe()
mgr.EXPECT().GetShard(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]shardclient.NodeInfo{
{NodeID: 1, Address: "mock_qn", Serviceable: true},
}, nil).Maybe()
mgr.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
s.mgr = mgr
s.lb = shardclient.NewLBPolicyImpl(mgr)

err := InitMetaCache(context.Background(), s.mixc, mgr)
err := InitMetaCache(context.Background(), s.mixc)
s.NoError(err)

s.collectionName = "test_statistics_task"
@ -178,8 +186,9 @@ func (s *StatisticTaskSuite) getStatisticsTask(ctx context.Context) *getStatisti
},
CollectionName: s.collectionName,
},
mixc: s.mixc,
lb: s.lb,
mixc: s.mixc,
lb: s.lb,
shardclientMgr: s.mgr,
}
}

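The test hunks above repeat one wiring pattern for the new package; a condensed, purely illustrative sketch of that setup follows (the helper name, stub values, and the coordinator client type are assumptions, not part of the patch):

package proxy

import (
	"context"
	"testing"

	"github.com/stretchr/testify/mock"

	"github.com/milvus-io/milvus/internal/mocks"
	"github.com/milvus-io/milvus/internal/proxy/shardclient"
	"github.com/milvus-io/milvus/internal/types"
)

// sketchShardClientTestSetup condenses the setup repeated in the tests above:
// a mocked shard client manager backs both the LB policy and the task under test,
// and InitMetaCache no longer receives the manager.
func sketchShardClientTestSetup(ctx context.Context, t *testing.T, qc types.MixCoordClient) (shardclient.LBPolicy, shardclient.ShardClientMgr) {
	qn := mocks.NewMockQueryNodeClient(t)
	mgr := shardclient.NewMockShardClientManager(t)
	mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
	mgr.EXPECT().GetShardLeaderList(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]string{"mock_qn"}, nil).Maybe()
	mgr.EXPECT().GetShard(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]shardclient.NodeInfo{
		{NodeID: 1, Address: "mock_qn", Serviceable: true},
	}, nil).Maybe()
	mgr.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()

	lb := shardclient.NewLBPolicyImpl(mgr)
	// the shard client manager argument was dropped from InitMetaCache by this change
	if err := InitMetaCache(ctx, qc); err != nil {
		t.Fatal(err)
	}
	return lb, mgr
}
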
@ -473,8 +473,7 @@ func TestAlterCollection_AllowInsertAutoID_Validation(t *testing.T) {
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
root := buildRoot(true)
mgr := newShardClientMgr()
err := InitMetaCache(ctx, root, mgr)
err := InitMetaCache(ctx, root)
assert.NoError(t, err)

task := &alterCollectionTask{
@ -495,8 +494,7 @@ func TestAlterCollection_AllowInsertAutoID_Validation(t *testing.T) {
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
root := buildRoot(false)
mgr := newShardClientMgr()
err := InitMetaCache(ctx, root, mgr)
err := InitMetaCache(ctx, root)
assert.NoError(t, err)

task := &alterCollectionTask{
@ -1602,8 +1600,7 @@ func TestHasCollectionTask(t *testing.T) {
mixc := NewMixCoordMock()
defer mixc.Close()
ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, mixc, mgr)
InitMetaCache(ctx, mixc)
prefix := "TestHasCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
@ -1699,8 +1696,7 @@ func TestDescribeCollectionTask(t *testing.T) {
mixc := NewMixCoordMock()
defer mixc.Close()
ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, mixc, mgr)
InitMetaCache(ctx, mixc)
prefix := "TestDescribeCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
@ -1757,8 +1753,7 @@ func TestDescribeCollectionTask_ShardsNum1(t *testing.T) {
mix := NewMixCoordMock()

ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, mix, mgr)
InitMetaCache(ctx, mix)
prefix := "TestDescribeCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
@ -1816,8 +1811,7 @@ func TestDescribeCollectionTask_ShardsNum1(t *testing.T) {
func TestDescribeCollectionTask_EnableDynamicSchema(t *testing.T) {
mix := NewMixCoordMock()
ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, mix, mgr)
InitMetaCache(ctx, mix)
prefix := "TestDescribeCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
@ -1876,8 +1870,7 @@ func TestDescribeCollectionTask_EnableDynamicSchema(t *testing.T) {
func TestDescribeCollectionTask_ShardsNum2(t *testing.T) {
mix := NewMixCoordMock()
ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, mix, mgr)
InitMetaCache(ctx, mix)
prefix := "TestDescribeCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
@ -2238,8 +2231,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
qc := NewMixCoordMock()
ctx := context.Background()

mgr := newShardClientMgr()
err = InitMetaCache(ctx, qc, mgr)
err = InitMetaCache(ctx, qc)
assert.NoError(t, err)

shardsNum := int32(2)
@ -2467,8 +2459,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {

ctx := context.Background()

mgr := newShardClientMgr()
err = InitMetaCache(ctx, mixc, mgr)
err = InitMetaCache(ctx, mixc)
assert.NoError(t, err)

shardsNum := int32(2)
@ -3086,9 +3077,8 @@ func Test_loadCollectionTask_Execute(t *testing.T) {
ctx := context.Background()
indexID := int64(1000)

shardMgr := newShardClientMgr()
// failed to get collection id.
_ = InitMetaCache(ctx, rc, shardMgr)
_ = InitMetaCache(ctx, rc)

rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) {
return &milvuspb.DescribeCollectionResponse{
@ -3212,9 +3202,8 @@ func Test_loadPartitionTask_Execute(t *testing.T) {
ctx := context.Background()
indexID := int64(1000)

shardMgr := newShardClientMgr()
// failed to get collection id.
_ = InitMetaCache(ctx, qc, shardMgr)
_ = InitMetaCache(ctx, qc)

qc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) {
return &milvuspb.DescribeCollectionResponse{
@ -3297,8 +3286,7 @@ func TestCreateResourceGroupTask(t *testing.T) {
defer mixc.Close()

ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, mixc, mgr)
InitMetaCache(ctx, mixc)

createRGReq := &milvuspb.CreateResourceGroupRequest{
Base: &commonpb.MsgBase{
@ -3335,8 +3323,7 @@ func TestDropResourceGroupTask(t *testing.T) {
defer mixc.Close()

ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, mixc, mgr)
InitMetaCache(ctx, mixc)

dropRGReq := &milvuspb.DropResourceGroupRequest{
Base: &commonpb.MsgBase{
@ -3372,8 +3359,7 @@ func TestTransferNodeTask(t *testing.T) {

defer mixc.Close()
ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, mixc, mgr)
InitMetaCache(ctx, mixc)

req := &milvuspb.TransferNodeRequest{
Base: &commonpb.MsgBase{
@ -3410,8 +3396,7 @@ func TestTransferReplicaTask(t *testing.T) {
rc := &MockMixCoordClientInterface{}

ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, rc, mgr)
InitMetaCache(ctx, rc)
// make it avoid remote call on rc
globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), "collection1")

@ -3458,8 +3443,7 @@ func TestListResourceGroupsTask(t *testing.T) {
}

ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, rc, mgr)
InitMetaCache(ctx, rc)

req := &milvuspb.ListResourceGroupsRequest{
Base: &commonpb.MsgBase{
@ -3508,8 +3492,7 @@ func TestDescribeResourceGroupTask(t *testing.T) {
}

ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, rc, mgr)
InitMetaCache(ctx, rc)
// make it avoid remote call on rc
globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), "collection1")
globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), "collection2")
@ -3558,8 +3541,7 @@ func TestDescribeResourceGroupTaskFailed(t *testing.T) {
}

ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, rc, mgr)
InitMetaCache(ctx, rc)
// make it avoid remote call on rc
globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), "collection1")
globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), "collection2")
@ -3807,7 +3789,7 @@ func TestCreateCollectionTaskWithPartitionKey(t *testing.T) {
assert.NoError(t, err)

// check default partitions
err = InitMetaCache(ctx, rc, nil)
err = InitMetaCache(ctx, rc)
assert.NoError(t, err)
partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, "", task.CollectionName)
assert.NoError(t, err)
@ -3886,8 +3868,7 @@ func TestPartitionKey(t *testing.T) {
defer qc.Close()
ctx := context.Background()

mgr := newShardClientMgr()
err := InitMetaCache(ctx, qc, mgr)
err := InitMetaCache(ctx, qc)
assert.NoError(t, err)

shardsNum := common.DefaultShardsNum
@ -4121,8 +4102,7 @@ func TestDefaultPartition(t *testing.T) {
qc := NewMixCoordMock()
ctx := context.Background()

mgr := newShardClientMgr()
err := InitMetaCache(ctx, qc, mgr)
err := InitMetaCache(ctx, qc)
assert.NoError(t, err)

shardsNum := common.DefaultShardsNum
@ -4302,8 +4282,7 @@ func TestClusteringKey(t *testing.T) {

ctx := context.Background()

mgr := newShardClientMgr()
err := InitMetaCache(ctx, qc, mgr)
err := InitMetaCache(ctx, qc)
assert.NoError(t, err)

shardsNum := common.DefaultShardsNum
@ -4439,7 +4418,7 @@ func TestClusteringKey(t *testing.T) {

func TestAlterCollectionCheckLoaded(t *testing.T) {
qc := NewMixCoordMock()
InitMetaCache(context.Background(), qc, nil)
InitMetaCache(context.Background(), qc)
collectionName := "test_alter_collection_check_loaded"
createColReq := &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{
@ -4481,8 +4460,7 @@ func TestAlterCollectionCheckLoaded(t *testing.T) {
func TestTaskPartitionKeyIsolation(t *testing.T) {
qc := NewMixCoordMock()
ctx := context.Background()
mgr := newShardClientMgr()
err := InitMetaCache(ctx, qc, mgr)
err := InitMetaCache(ctx, qc)
assert.NoError(t, err)
shardsNum := common.DefaultShardsNum
prefix := "TestPartitionKeyIsolation"
@ -4828,7 +4806,7 @@ func TestInsertForReplicate(t *testing.T) {

func TestAlterCollectionFieldCheckLoaded(t *testing.T) {
qc := NewMixCoordMock()
InitMetaCache(context.Background(), qc, nil)
InitMetaCache(context.Background(), qc)
collectionName := "test_alter_collection_field_check_loaded"
createColReq := &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{
@ -4880,7 +4858,7 @@ func TestAlterCollectionField(t *testing.T) {

func TestAlterCollectionField(t *testing.T) {
qc := NewMixCoordMock()
InitMetaCache(context.Background(), qc, nil)
InitMetaCache(context.Background(), qc)
collectionName := "test_alter_collection_field"

// Create collection with string and array fields
@ -5483,7 +5461,7 @@ func TestDescribeCollectionTaskWithStructArrayField(t *testing.T) {

func TestAlterCollection_AllowInsertAutoID_AutoIDFalse(t *testing.T) {
qc := NewMixCoordMock()
InitMetaCache(context.Background(), qc, nil)
InitMetaCache(context.Background(), qc)
ctx := context.Background()
collectionName := "test_alter_allow_insert_autoid_autoid_false"

@ -32,6 +32,7 @@ import (
grpcmixcoordclient "github.com/milvus-io/milvus/internal/distributed/mixcoord/client"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/util/function/embedding"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
@ -600,7 +601,7 @@ func createTestUpdateTask() *upsertTask {
collectionID: 1001,
node: &Proxy{
mixCoord: mcClient,
lbPolicy: NewLBPolicyImpl(nil),
lbPolicy: shardclient.NewLBPolicyImpl(nil),
},
}