enhance: [2.6] extract shard client logic into dedicated package (#45018) (#45031)

Cherry-pick from master
pr: #45018 #45030
Related to #44761

Refactor proxy shard client management by creating a new
internal/proxy/shardclient package. This improves code organization and
modularity by:

- Moving load-balancing logic (LookAsideBalancer, RoundRobinBalancer) to the
shardclient package
- Extracting the shard client manager and related interfaces into a separate
package
- Relocating shard leader management and client lifecycle code
- Adding package documentation (README.md, OWNERS)
- Updating proxy code to use the new shardclient package interfaces

This change makes the shard client functionality more maintainable and
better encapsulated, reducing coupling in the proxy layer.
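
For orientation, here is a minimal sketch (not code from this PR) of the call pattern the refactor introduces: shard-leader cache bookkeeping moves from the global meta cache onto the `ShardClientMgr` interface in the new `shardclient` package. The method names match those used in the diff below; the helper function and its package are illustrative.

```go
// Sketch only: illustrative helper, not part of this PR.
package example

import "github.com/milvus-io/milvus/internal/proxy/shardclient"

// invalidateShardCaches shows the shard-leader operations that now live behind
// shardclient.ShardClientMgr instead of the proxy's globalMetaCache.
func invalidateShardCaches(mgr shardclient.ShardClientMgr, dbName, collectionName string, collectionIDs []int64) {
	// Drop cached shard leaders for one collection (was globalMetaCache.DeprecateShardCache).
	mgr.DeprecateShardCache(dbName, collectionName)
	// Invalidate cached leaders when shard-leader balancing touches these collections.
	mgr.InvalidateShardLeaderCache(collectionIDs)
	// Dropping a database clears its shard-leader entries as well.
	mgr.RemoveDatabase(dbName)
}
```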

Also consolidates the proxy package's mockery generation to use a
centralized `.mockery.yaml` configuration file, aligning with the
pattern used by other packages like querycoordv2.

Changes
- **Makefile**: Replace the multiple individual mockery commands with a
single config-based invocation for the `generate-mockery-proxy` target
- **internal/proxy/.mockery.yaml**: Add mockery configuration defining
all mock interfaces for proxy and proxy/shardclient packages
- **Mock files**: Regenerate mocks using the new configuration:
  - `mock_cache.go`: Clean up by removing unused interface methods
  (credential, shard cache, policy methods)
  - `shardclient/mock_lb_balancer.go`: Update type comments
  (nodeInfo → NodeInfo)
  - `shardclient/mock_lb_policy.go`: Update formatting
  - `shardclient/mock_shardclient_manager.go`: Fix parameter naming
  consistency (nodeInfo1 → nodeInfo)
- **task_search_test.go**: Remove obsolete mock expectations for
deprecated cache methods

Benefits
- Centralized mockery configuration for easier maintenance
- Consistent with other packages (querycoordv2, etc.)
- Cleaner mock interfaces by removing unused methods
- Better type consistency in generated mocks
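
For a usage illustration, a hedged sketch of how the regenerated mocks are wired into proxy tests follows; the constructor and expectation mirror the updated tests in this diff, while the test name, the simplified node construction, and the direct mock call are illustrative.

```go
// Sketch: wiring a regenerated shardclient mock into a proxy test.
package proxy

import (
	"testing"

	"github.com/stretchr/testify/mock"

	"github.com/milvus-io/milvus/internal/proxy/shardclient"
)

func TestShardClientMockWiring(t *testing.T) {
	// Simplified construction for illustration; real tests build a full Proxy.
	node := &Proxy{}

	// Mock generated from internal/proxy/.mockery.yaml for the ShardClientMgr interface.
	mockShardClientMgr := shardclient.NewMockShardClientManager(t)
	mockShardClientMgr.EXPECT().InvalidateShardLeaderCache(mock.Anything).Return()
	node.shardMgr = mockShardClientMgr

	// Exercised directly here to satisfy the expectation; the real test goes through
	// node.InvalidateShardLeaderCache instead. MockLBPolicy from the same package is
	// wired the same way via node.lbPolicy.
	node.shardMgr.InvalidateShardLeaderCache([]int64{100})
}
```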

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
congqixia 2025-10-22 16:06:06 +08:00 committed by GitHub
parent b3e525609c
commit a592cfc8b4
48 changed files with 3000 additions and 2706 deletions


@ -478,12 +478,7 @@ generate-mockery-rootcoord: getdeps
$(INSTALL_PATH)/mockery --name=IMetaTable --dir=$(PWD)/internal/rootcoord --output=$(PWD)/internal/rootcoord/mocks --filename=meta_table.go --with-expecter --outpkg=mockrootcoord
generate-mockery-proxy: getdeps
$(INSTALL_PATH)/mockery --name=Cache --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_cache.go --structname=MockCache --with-expecter --outpkg=proxy --inpackage
$(INSTALL_PATH)/mockery --name=timestampAllocatorInterface --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_tso_test.go --structname=mockTimestampAllocator --with-expecter --outpkg=proxy --inpackage
$(INSTALL_PATH)/mockery --name=LBPolicy --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_lb_policy.go --structname=MockLBPolicy --with-expecter --outpkg=proxy --inpackage
$(INSTALL_PATH)/mockery --name=LBBalancer --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_lb_balancer.go --structname=MockLBBalancer --with-expecter --outpkg=proxy --inpackage
$(INSTALL_PATH)/mockery --name=shardClientMgr --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_shardclient_manager.go --structname=MockShardClientManager --with-expecter --outpkg=proxy --inpackage
$(INSTALL_PATH)/mockery --name=channelsMgr --dir=$(PWD)/internal/proxy --output=$(PWD)/internal/proxy --filename=mock_channels_manager.go --structname=MockChannelsMgr --with-expecter --outpkg=proxy --inpackage
$(INSTALL_PATH)/mockery --config $(PWD)/internal/proxy/.mockery.yaml
generate-mockery-querycoord: getdeps
$(INSTALL_PATH)/mockery --config $(PWD)/internal/querycoordv2/.mockery.yaml


@ -0,0 +1,31 @@
quiet: False
with-expecter: True
inpackage: True
filename: "mock_{{.InterfaceNameSnake}}.go"
mockname: "Mock{{.InterfaceName}}"
outpkg: "{{.PackageName}}"
dir: "{{.InterfaceDir}}"
packages:
  github.com/milvus-io/milvus/internal/proxy:
    interfaces:
      Cache:
      timestampAllocatorInterface:
        config:
          mockname: mockTimestampAllocator
          filename: mock_tso_test.go
      channelsMgr:
        config:
          mockname: MockChannelsMgr
          filename: mock_channels_manager.go
  github.com/milvus-io/milvus/internal/proxy/shardclient:
    interfaces:
      LBPolicy:
        config:
          filename: mock_lb_policy.go
      LBBalancer:
        config:
          filename: mock_lb_balancer.go
      ShardClientMgr:
        config:
          filename: mock_shardclient_manager.go
          mockname: MockShardClientManager


@ -40,8 +40,7 @@ func TestValidAuth(t *testing.T) {
assert.False(t, res)
// normal metadata
mix := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, mix, mgr)
err := InitMetaCache(ctx, mix)
assert.NoError(t, err)
res = validAuth(ctx, []string{crypto.Base64Encode("mockUser:mockPass")})
assert.True(t, res)
@ -72,8 +71,7 @@ func TestAuthenticationInterceptor(t *testing.T) {
assert.Error(t, err)
// mock metacache
queryCoord := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
err = InitMetaCache(ctx, queryCoord, mgr)
err = InitMetaCache(ctx, queryCoord)
assert.NoError(t, err)
// with invalid metadata
md := metadata.Pairs("xxx", "yyy")


@ -137,12 +137,12 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
if request.CollectionID != UniqueID(0) {
aliasName = globalMetaCache.RemoveCollectionsByID(ctx, collectionID, request.GetBase().GetTimestamp(), msgType == commonpb.MsgType_DropCollection)
for _, name := range aliasName {
globalMetaCache.DeprecateShardCache(request.GetDbName(), name)
node.shardMgr.DeprecateShardCache(request.GetDbName(), name)
}
}
if collectionName != "" {
globalMetaCache.RemoveCollection(ctx, request.GetDbName(), collectionName) // no need to return error, though collection may be not cached
globalMetaCache.DeprecateShardCache(request.GetDbName(), collectionName)
node.shardMgr.DeprecateShardCache(request.GetDbName(), collectionName)
}
log.Info("complete to invalidate collection meta cache with collection name", zap.String("type", request.GetBase().GetMsgType().String()))
case commonpb.MsgType_LoadCollection, commonpb.MsgType_ReleaseCollection:
@ -150,7 +150,7 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
if request.CollectionID != UniqueID(0) {
aliasName = globalMetaCache.RemoveCollectionsByID(ctx, collectionID, 0, false)
for _, name := range aliasName {
globalMetaCache.DeprecateShardCache(request.GetDbName(), name)
node.shardMgr.DeprecateShardCache(request.GetDbName(), name)
}
}
log.Info("complete to invalidate collection meta cache", zap.String("type", request.GetBase().GetMsgType().String()))
@ -165,13 +165,16 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
}
globalMetaCache.RemoveCollection(ctx, request.GetDbName(), collectionName)
log.Info("complete to invalidate collection meta cache", zap.String("type", request.GetBase().GetMsgType().String()))
case commonpb.MsgType_DropDatabase, commonpb.MsgType_AlterDatabase:
case commonpb.MsgType_DropDatabase:
node.shardMgr.RemoveDatabase(request.GetDbName())
fallthrough
case commonpb.MsgType_AlterDatabase:
globalMetaCache.RemoveDatabase(ctx, request.GetDbName())
case commonpb.MsgType_AlterCollection, commonpb.MsgType_AlterCollectionField:
if request.CollectionID != UniqueID(0) {
aliasName = globalMetaCache.RemoveCollectionsByID(ctx, collectionID, 0, false)
for _, name := range aliasName {
globalMetaCache.DeprecateShardCache(request.GetDbName(), name)
node.shardMgr.DeprecateShardCache(request.GetDbName(), name)
}
}
if collectionName != "" {
@ -183,13 +186,13 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p
if request.CollectionID != UniqueID(0) {
aliasName = globalMetaCache.RemoveCollectionsByID(ctx, collectionID, request.GetBase().GetTimestamp(), false)
for _, name := range aliasName {
globalMetaCache.DeprecateShardCache(request.GetDbName(), name)
node.shardMgr.DeprecateShardCache(request.GetDbName(), name)
}
}
if collectionName != "" {
globalMetaCache.RemoveCollection(ctx, request.GetDbName(), collectionName) // no need to return error, though collection may be not cached
globalMetaCache.DeprecateShardCache(request.GetDbName(), collectionName)
node.shardMgr.DeprecateShardCache(request.GetDbName(), collectionName)
}
}
}
@ -227,9 +230,8 @@ func (node *Proxy) InvalidateShardLeaderCache(ctx context.Context, request *prox
log.Info("received request to invalidate shard leader cache", zap.Int64s("collectionIDs", request.GetCollectionIDs()))
if globalMetaCache != nil {
globalMetaCache.InvalidateShardLeaderCache(request.GetCollectionIDs())
}
node.shardMgr.InvalidateShardLeaderCache(request.GetCollectionIDs())
log.Info("complete to invalidate shard leader cache", zap.Int64s("collectionIDs", request.GetCollectionIDs()))
return merr.Success(), nil
@ -2791,6 +2793,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,
mixCoord: node.mixCoord,
node: node,
lb: node.lbPolicy,
shardClientMgr: node.shardMgr,
enableMaterializedView: node.enableMaterializedView,
mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(),
}
@ -3024,6 +3027,7 @@ func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSea
mixCoord: node.mixCoord,
node: node,
lb: node.lbPolicy,
shardClientMgr: node.shardMgr,
mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(),
}
@ -3399,6 +3403,7 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
request: request,
mixCoord: node.mixCoord,
lb: node.lbPolicy,
shardclientMgr: node.shardMgr,
mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(),
}


@ -40,6 +40,7 @@ import (
grpcmixcoordclient "github.com/milvus-io/milvus/internal/distributed/mixcoord/client"
mhttp "github.com/milvus-io/milvus/internal/http"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/pkg/v2/common"
@ -247,8 +248,8 @@ func TestProxy_ResourceGroup(t *testing.T) {
node.sched.Start()
defer node.sched.Close()
mgr := newShardClientMgr()
InitMetaCache(ctx, qc, mgr)
// mgr := newShardClientMgr()
InitMetaCache(ctx, qc)
successStatus := &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
@ -329,8 +330,8 @@ func TestProxy_InvalidResourceGroupName(t *testing.T) {
node.sched.Start()
defer node.sched.Close()
mgr := newShardClientMgr()
InitMetaCache(ctx, qc, mgr)
// mgr := newShardClientMgr()
InitMetaCache(ctx, qc)
t.Run("create resource group", func(t *testing.T) {
resp, err := node.CreateResourceGroup(ctx, &milvuspb.CreateResourceGroupRequest{
@ -1162,7 +1163,7 @@ func TestProxyDescribeCollection(t *testing.T) {
}, nil).Maybe()
mixCoord.On("DescribeCollection", mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe()
var err error
globalMetaCache, err = NewMetaCache(mixCoord, nil)
globalMetaCache, err = NewMetaCache(mixCoord)
assert.NoError(t, err)
t.Run("not healthy", func(t *testing.T) {
@ -1589,9 +1590,9 @@ func TestProxy_InvalidateShardLeaderCache(t *testing.T) {
cacheBak := globalMetaCache
defer func() { globalMetaCache = cacheBak }()
// set expectations
cache := NewMockCache(t)
cache.EXPECT().InvalidateShardLeaderCache(mock.Anything)
globalMetaCache = cache
mockShardClientMgr := shardclient.NewMockShardClientManager(t)
mockShardClientMgr.EXPECT().InvalidateShardLeaderCache(mock.Anything).Return()
node.shardMgr = mockShardClientMgr
resp, err := node.InvalidateShardLeaderCache(context.TODO(), &proxypb.InvalidateShardLeaderCacheRequest{})
assert.NoError(t, err)
@ -1698,7 +1699,7 @@ func TestRunAnalyzer(t *testing.T) {
fieldMap: fieldMap,
}, nil)
lb := NewMockLBPolicy(t)
lb := shardclient.NewMockLBPolicy(t)
lb.EXPECT().ExecuteOneChannel(mock.Anything, mock.Anything).Return(nil)
p.lbPolicy = lb


@ -1,765 +0,0 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"context"
"reflect"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"go.uber.org/atomic"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
type LBPolicySuite struct {
suite.Suite
qc types.MixCoordClient
qn *mocks.MockQueryNodeClient
mgr *MockShardClientManager
lbBalancer *MockLBBalancer
lbPolicy *LBPolicyImpl
nodeIDs []int64
nodes []nodeInfo
channels []string
qnList []*mocks.MockQueryNode
collectionName string
collectionID int64
}
func (s *LBPolicySuite) SetupSuite() {
paramtable.Init()
}
func (s *LBPolicySuite) SetupTest() {
s.nodeIDs = make([]int64, 0)
for i := 1; i <= 5; i++ {
s.nodeIDs = append(s.nodeIDs, int64(i))
s.nodes = append(s.nodes, nodeInfo{
nodeID: int64(i),
address: "localhost",
serviceable: true,
})
}
s.channels = []string{"channel1", "channel2"}
successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
qc := NewMixCoordMock()
qc.GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: &successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: s.channels[0],
NodeIds: s.nodeIDs,
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002", "localhost:9003", "localhost:9004"},
Serviceable: []bool{true, true, true, true, true},
},
{
ChannelName: s.channels[1],
NodeIds: s.nodeIDs,
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002", "localhost:9003", "localhost:9004"},
Serviceable: []bool{true, true, true, true, true},
},
},
}, nil
}
qc.ShowLoadPartitionsFunc = func(ctx context.Context, req *querypb.ShowPartitionsRequest, opts ...grpc.CallOption) (*querypb.ShowPartitionsResponse, error) {
return &querypb.ShowPartitionsResponse{
Status: merr.Success(),
PartitionIDs: []int64{1, 2, 3},
}, nil
}
s.qc = qc
s.qn = mocks.NewMockQueryNodeClient(s.T())
s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
s.mgr = NewMockShardClientManager(s.T())
s.lbBalancer = NewMockLBBalancer(s.T())
s.lbBalancer.EXPECT().Start(context.Background()).Maybe()
s.lbPolicy = NewLBPolicyImpl(s.mgr)
s.lbPolicy.Start(context.Background())
s.lbPolicy.getBalancer = func() LBBalancer {
return s.lbBalancer
}
err := InitMetaCache(context.Background(), s.qc, s.mgr)
s.NoError(err)
s.collectionName = "test_lb_policy"
s.loadCollection()
}
func (s *LBPolicySuite) loadCollection() {
fieldName2Types := map[string]schemapb.DataType{
testBoolField: schemapb.DataType_Bool,
testInt32Field: schemapb.DataType_Int32,
testInt64Field: schemapb.DataType_Int64,
testFloatField: schemapb.DataType_Float,
testDoubleField: schemapb.DataType_Double,
testFloatVecField: schemapb.DataType_FloatVector,
}
if enableMultipleVectorFields {
fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
}
schema := constructCollectionSchemaByDataType(s.collectionName, fieldName2Types, testInt64Field, false)
marshaledSchema, err := proto.Marshal(schema)
s.NoError(err)
ctx := context.Background()
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
CollectionName: s.collectionName,
DbName: dbName,
Schema: marshaledSchema,
ShardsNum: common.DefaultShardsNum,
},
ctx: ctx,
mixCoord: s.qc,
}
s.NoError(createColT.OnEnqueue())
s.NoError(createColT.PreExecute(ctx))
s.NoError(createColT.Execute(ctx))
s.NoError(createColT.PostExecute(ctx))
collectionID, err := globalMetaCache.GetCollectionID(ctx, dbName, s.collectionName)
s.NoError(err)
status, err := s.qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
SourceID: paramtable.GetNodeID(),
},
CollectionID: collectionID,
})
s.NoError(err)
s.Equal(commonpb.ErrorCode_Success, status.ErrorCode)
s.collectionID = collectionID
}
func (s *LBPolicySuite) TestSelectNode() {
ctx := context.Background()
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(5, nil)
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
// shardLeaders: s.nodes,
nq: 1,
}, &typeutil.UniqueSet{})
s.NoError(err)
s.Equal(int64(5), targetNode.nodeID)
// test select node failed, then update shard leader cache and retry, expect success
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, errors.New("fake err")).Times(1)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil)
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
// shardLeaders: s.nodes,
nq: 1,
}, &typeutil.UniqueSet{})
s.NoError(err)
s.Equal(int64(3), targetNode.nodeID)
// test select node always fails, expected failure
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
// shardLeaders: []nodeInfo{},
nq: 1,
}, &typeutil.UniqueSet{})
s.ErrorIs(err, merr.ErrNodeNotAvailable)
// test all nodes has been excluded, expected clear excludeNodes and try to select node again
excludeNodes := typeutil.NewUniqueSet(s.nodeIDs...)
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
// shardLeaders: s.nodes,
nq: 1,
}, &excludeNodes)
s.ErrorIs(err, merr.ErrNodeNotAvailable)
// test get shard leaders failed, retry to select node failed
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
return nil, merr.ErrServiceUnavailable
}
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
// shardLeaders: s.nodes,
nq: 1,
}, &typeutil.UniqueSet{})
s.ErrorIs(err, merr.ErrServiceUnavailable)
}
func (s *LBPolicySuite) TestExecuteWithRetry() {
ctx := context.Background()
// test execute success
s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err := s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
})
s.NoError(err)
// test select node failed, expected error
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
})
s.ErrorIs(err, merr.ErrNodeNotAvailable)
// test get client failed, and retry failed, expected success
s.mgr.ExpectedCalls = nil
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(2)
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
})
s.Error(err)
s.mgr.ExpectedCalls = nil
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Times(1)
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
return availableNodes[0], nil
})
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
})
s.NoError(err)
// test exec failed, then retry success
s.mgr.ExpectedCalls = nil
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
return availableNodes[0], nil
})
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
counter := 0
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
counter++
if counter == 1 {
return errors.New("fake error")
}
return nil
},
})
s.NoError(err)
// test exec timeout
s.mgr.ExpectedCalls = nil
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
s.qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, context.Canceled).Times(1)
s.qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, context.DeadlineExceeded)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
_, err := qn.Search(ctx, nil)
return err
},
})
s.True(merr.IsCanceledOrTimeout(err))
}
func (s *LBPolicySuite) TestExecuteOneChannel() {
ctx := context.Background()
mockErr := errors.New("mock error")
// test all channel success
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err := s.lbPolicy.ExecuteOneChannel(ctx, CollectionWorkLoad{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
})
s.NoError(err)
// test get shard leader failed
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
return nil, mockErr
}
err = s.lbPolicy.ExecuteOneChannel(ctx, CollectionWorkLoad{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
})
s.ErrorIs(err, mockErr)
}
func (s *LBPolicySuite) TestExecute() {
ctx := context.Background()
mockErr := errors.New("mock error")
// test all channel success
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
return availableNodes[0], nil
})
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
})
s.NoError(err)
// test some channel failed
counter := atomic.NewInt64(0)
err = s.lbPolicy.Execute(ctx, CollectionWorkLoad{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
// succeed in first execute
if counter.Add(1) == 1 {
return nil
}
return mockErr
},
})
s.Error(err)
s.Equal(int64(6), counter.Load())
// test get shard leader failed
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
return nil, mockErr
}
err = s.lbPolicy.Execute(ctx, CollectionWorkLoad{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
nq: 1,
exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
})
s.ErrorIs(err, mockErr)
}
func (s *LBPolicySuite) TestUpdateCostMetrics() {
s.lbBalancer.EXPECT().UpdateCostMetrics(mock.Anything, mock.Anything)
s.lbPolicy.UpdateCostMetrics(1, &internalpb.CostAggregation{})
}
func (s *LBPolicySuite) TestNewLBPolicy() {
policy := NewLBPolicyImpl(s.mgr)
s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*proxy.LookAsideBalancer")
policy.Close()
Params.Save(Params.ProxyCfg.ReplicaSelectionPolicy.Key, "round_robin")
policy = NewLBPolicyImpl(s.mgr)
s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*proxy.RoundRobinBalancer")
policy.Close()
Params.Save(Params.ProxyCfg.ReplicaSelectionPolicy.Key, "look_aside")
policy = NewLBPolicyImpl(s.mgr)
s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*proxy.LookAsideBalancer")
policy.Close()
}
func (s *LBPolicySuite) TestGetShard() {
ctx := context.Background()
// ErrCollectionNotFullyLoaded is retriable, expected to retry until ctx done or success
counter := atomic.NewInt64(0)
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
counter.Inc()
return nil, merr.ErrCollectionNotFullyLoaded
}
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
log.Info("return rpc success")
return nil, nil
}
_, err := s.lbPolicy.GetShard(ctx, dbName, s.collectionName, s.collectionID, s.channels[0], true)
s.NoError(err)
s.Equal(int64(0), counter.Load())
// ErrServiceUnavailable is not retriable, expected to fail fast
counter.Store(0)
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
counter.Inc()
return nil, merr.ErrCollectionNotLoaded
}
_, err = s.lbPolicy.GetShard(ctx, dbName, s.collectionName, s.collectionID, s.channels[0], true)
log.Info("check err", zap.Error(err))
s.Error(err)
s.Equal(int64(1), counter.Load())
}
func (s *LBPolicySuite) TestSelectNodeEdgeCases() {
ctx := context.Background()
// Test case 1: Empty shard leaders after refresh, should fail gracefully
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable).Times(1)
// Setup mock to return empty shard leaders
successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: &successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: s.channels[0],
NodeIds: []int64{}, // Empty node list
NodeAddrs: []string{},
Serviceable: []bool{},
},
},
}, nil
}
excludeNodes := typeutil.NewUniqueSet(s.nodeIDs...)
_, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
nq: 1,
}, &excludeNodes)
s.Error(err)
log.Info("test case 1")
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
// Test case 2: Single replica scenario - exclude it, refresh shows same single replica, should clear and succeed
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil).Times(1)
singleNodeList := []int64{1}
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: &successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: s.channels[0],
NodeIds: singleNodeList,
NodeAddrs: []string{"localhost:9000"},
Serviceable: []bool{true},
},
},
}, nil
}
excludeNodes = typeutil.NewUniqueSet(int64(1)) // Exclude the single node
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
nq: 1,
}, &excludeNodes)
s.NoError(err)
s.Equal(int64(1), targetNode.nodeID)
s.Equal(0, excludeNodes.Len()) // Should be cleared
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
mixedNodeIDs := []int64{1, 2, 3}
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(-1, merr.ErrNodeNotAvailable).Times(1)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(3, nil).Times(1)
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: &successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: s.channels[0],
NodeIds: mixedNodeIDs,
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, false, true},
},
},
}, nil
}
excludeNodes = typeutil.NewUniqueSet(int64(1)) // Exclude node 1, node 3 should be available
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
nq: 1,
}, &excludeNodes)
s.NoError(err)
s.Equal(int64(3), targetNode.nodeID)
s.Equal(1, excludeNodes.Len()) // Should NOT be cleared as not all replicas were excluded
}
func (s *LBPolicySuite) TestGetShardLeaderList() {
ctx := context.Background()
// Test normal scenario with cache
channelList, err := s.lbPolicy.GetShardLeaderList(ctx, dbName, s.collectionName, s.collectionID, true)
s.NoError(err)
s.Equal(len(s.channels), len(channelList))
s.Contains(channelList, s.channels[0])
s.Contains(channelList, s.channels[1])
// Test without cache - should refresh from coordinator
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
channelList, err = s.lbPolicy.GetShardLeaderList(ctx, dbName, s.collectionName, s.collectionID, false)
s.NoError(err)
s.Equal(len(s.channels), len(channelList))
// Test error case - collection not loaded
counter := atomic.NewInt64(0)
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
counter.Inc()
return nil, merr.ErrCollectionNotLoaded
}
_, err = s.lbPolicy.GetShardLeaderList(ctx, dbName, s.collectionName, s.collectionID, true)
s.Error(err)
s.ErrorIs(err, merr.ErrCollectionNotLoaded)
s.Equal(int64(1), counter.Load())
}
func (s *LBPolicySuite) TestSelectNodeWithExcludeClearing() {
ctx := context.Background()
// Test exclude nodes clearing when all replicas are excluded after cache refresh
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
// First attempt fails due to no candidates
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(1, nil).Times(1)
// Setup mock to return only excluded nodes first, then same nodes for retry
successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: &successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: s.channels[0],
NodeIds: []int64{1, 2}, // All these will be excluded
NodeAddrs: []string{"localhost:9000", "localhost:9001"},
Serviceable: []bool{true, true},
},
},
}, nil
}
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
excludeNodes := typeutil.NewUniqueSet(int64(1), int64(2)) // Exclude all available nodes
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
nq: 1,
}, &excludeNodes)
s.NoError(err)
s.Equal(int64(1), targetNode.nodeID)
s.Equal(0, excludeNodes.Len()) // Should be cleared when all replicas were excluded
// Test exclude nodes NOT cleared when only partial replicas are excluded
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(2, nil).Times(1)
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: &successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: s.channels[0],
NodeIds: []int64{1, 2, 3}, // Node 2 and 3 are still available
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil
}
excludeNodes = typeutil.NewUniqueSet(int64(1)) // Only exclude node 1
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
nq: 1,
}, &excludeNodes)
s.NoError(err)
s.Equal(int64(2), targetNode.nodeID)
s.Equal(1, excludeNodes.Len()) // Should NOT be cleared as not all replicas were excluded
// Test empty shard leaders scenario
s.lbBalancer.ExpectedCalls = nil
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
globalMetaCache.DeprecateShardCache(dbName, s.collectionName)
s.qc.(*MixCoordMock).GetShardLeadersFunc = func(ctx context.Context, req *querypb.GetShardLeadersRequest, opts ...grpc.CallOption) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: &successStatus,
Shards: []*querypb.ShardLeadersList{
{
ChannelName: s.channels[0],
NodeIds: []int64{}, // Empty shard leaders
NodeAddrs: []string{},
Serviceable: []bool{},
},
},
}, nil
}
excludeNodes = typeutil.NewUniqueSet(int64(1))
_, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
db: dbName,
collectionName: s.collectionName,
collectionID: s.collectionID,
channel: s.channels[0],
nq: 1,
}, &excludeNodes)
s.Error(err)
s.Equal(1, excludeNodes.Len()) // Should NOT be cleared for empty shard leaders
}
func TestLBPolicySuite(t *testing.T) {
suite.Run(t, new(LBPolicySuite))
}


@ -19,13 +19,11 @@ package proxy
import (
"context"
"fmt"
"math/rand"
"strconv"
"strings"
"sync"
"github.com/samber/lo"
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -37,7 +35,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/conc"
@ -68,11 +65,11 @@ type Cache interface {
GetPartitionsIndex(ctx context.Context, database, collectionName string) ([]string, error)
// GetCollectionSchema get collection's schema.
GetCollectionSchema(ctx context.Context, database, collectionName string) (*schemaInfo, error)
GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]nodeInfo, error)
GetShardLeaderList(ctx context.Context, database, collectionName string, collectionID int64, withCache bool) ([]string, error)
DeprecateShardCache(database, collectionName string)
InvalidateShardLeaderCache(collections []int64)
ListShardLocation() map[int64]nodeInfo
// GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]nodeInfo, error)
// GetShardLeaderList(ctx context.Context, database, collectionName string, collectionID int64, withCache bool) ([]string, error)
// DeprecateShardCache(database, collectionName string)
// InvalidateShardLeaderCache(collections []int64)
// ListShardLocation() map[int64]nodeInfo
RemoveCollection(ctx context.Context, database, collectionName string)
RemoveCollectionsByID(ctx context.Context, collectionID UniqueID, version uint64, removeVersion bool) []string
@ -288,58 +285,6 @@ func (info *collectionInfo) isCollectionCached() bool {
return info != nil && info.collID != UniqueID(0) && info.schema != nil
}
// shardLeaders wraps shard leader mapping for iteration.
type shardLeaders struct {
idx *atomic.Int64
collectionID int64
shardLeaders map[string][]nodeInfo
}
func (sl *shardLeaders) Get(channel string) []nodeInfo {
return sl.shardLeaders[channel]
}
func (sl *shardLeaders) GetShardLeaderList() []string {
return lo.Keys(sl.shardLeaders)
}
type shardLeadersReader struct {
leaders *shardLeaders
idx int64
}
// Shuffle returns the shuffled shard leader list.
func (it shardLeadersReader) Shuffle() map[string][]nodeInfo {
result := make(map[string][]nodeInfo)
for channel, leaders := range it.leaders.shardLeaders {
l := len(leaders)
// shuffle all replica at random order
shuffled := make([]nodeInfo, l)
for i, randIndex := range rand.Perm(l) {
shuffled[i] = leaders[randIndex]
}
// make each copy has same probability to be first replica
for index, leader := range shuffled {
if leader == leaders[int(it.idx)%l] {
shuffled[0], shuffled[index] = shuffled[index], shuffled[0]
}
}
result[channel] = shuffled
}
return result
}
// GetReader returns shuffer reader for shard leader.
func (sl *shardLeaders) GetReader() shardLeadersReader {
idx := sl.idx.Inc()
return shardLeadersReader{
leaders: sl,
idx: idx,
}
}
// make sure MetaCache implements Cache.
var _ Cache = (*MetaCache)(nil)
@ -347,18 +292,17 @@ var _ Cache = (*MetaCache)(nil)
type MetaCache struct {
mixCoord types.MixCoordClient
dbInfo map[string]*databaseInfo // database -> db_info
collInfo map[string]map[string]*collectionInfo // database -> collectionName -> collection_info
collLeader map[string]map[string]*shardLeaders // database -> collectionName -> collection_leaders
dbInfo map[string]*databaseInfo // database -> db_info
collInfo map[string]map[string]*collectionInfo // database -> collectionName -> collection_info
credMap map[string]*internalpb.CredentialInfo // cache for credential, lazy load
privilegeInfos map[string]struct{} // privileges cache
userToRoles map[string]map[string]struct{} // user to role cache
mu sync.RWMutex
credMut sync.RWMutex
leaderMut sync.RWMutex
shardMgr shardClientMgr
sfGlobal conc.Singleflight[*collectionInfo]
sfDB conc.Singleflight[*databaseInfo]
sfGlobal conc.Singleflight[*collectionInfo]
sfDB conc.Singleflight[*databaseInfo]
IDStart int64
IDCount int64
@ -372,9 +316,9 @@ type MetaCache struct {
var globalMetaCache Cache
// InitMetaCache initializes globalMetaCache
func InitMetaCache(ctx context.Context, mixCoord types.MixCoordClient, shardMgr shardClientMgr) error {
func InitMetaCache(ctx context.Context, mixCoord types.MixCoordClient) error {
var err error
globalMetaCache, err = NewMetaCache(mixCoord, shardMgr)
globalMetaCache, err = NewMetaCache(mixCoord)
if err != nil {
return err
}
@ -390,14 +334,12 @@ func InitMetaCache(ctx context.Context, mixCoord types.MixCoordClient, shardMgr
}
// NewMetaCache creates a MetaCache with provided RootCoord and QueryNode
func NewMetaCache(mixCoord types.MixCoordClient, shardMgr shardClientMgr) (*MetaCache, error) {
func NewMetaCache(mixCoord types.MixCoordClient) (*MetaCache, error) {
return &MetaCache{
mixCoord: mixCoord,
dbInfo: map[string]*databaseInfo{},
collInfo: map[string]map[string]*collectionInfo{},
collLeader: map[string]map[string]*shardLeaders{},
credMap: map[string]*internalpb.CredentialInfo{},
shardMgr: shardMgr,
privilegeInfos: map[string]struct{}{},
userToRoles: map[string]map[string]struct{}{},
collectionCacheVersion: make(map[UniqueID]uint64),
@ -898,192 +840,12 @@ func (m *MetaCache) RemoveCollectionsByID(ctx context.Context, collectionID Uniq
return collNames
}
func (m *MetaCache) GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]nodeInfo, error) {
method := "GetShard"
// check cache first
cacheShardLeaders := m.getCachedShardLeaders(database, collectionName, method)
if cacheShardLeaders == nil || !withCache {
// refresh shard leader cache
newShardLeaders, err := m.updateShardLocationCache(ctx, database, collectionName, collectionID)
if err != nil {
return nil, err
}
cacheShardLeaders = newShardLeaders
}
return cacheShardLeaders.Get(channel), nil
}
func (m *MetaCache) GetShardLeaderList(ctx context.Context, database, collectionName string, collectionID int64, withCache bool) ([]string, error) {
method := "GetShardLeaderList"
// check cache first
cacheShardLeaders := m.getCachedShardLeaders(database, collectionName, method)
if cacheShardLeaders == nil || !withCache {
// refresh shard leader cache
newShardLeaders, err := m.updateShardLocationCache(ctx, database, collectionName, collectionID)
if err != nil {
return nil, err
}
cacheShardLeaders = newShardLeaders
}
return cacheShardLeaders.GetShardLeaderList(), nil
}
func (m *MetaCache) getCachedShardLeaders(database, collectionName, caller string) *shardLeaders {
m.leaderMut.RLock()
var cacheShardLeaders *shardLeaders
db, ok := m.collLeader[database]
if !ok {
cacheShardLeaders = nil
} else {
cacheShardLeaders = db[collectionName]
}
m.leaderMut.RUnlock()
if cacheShardLeaders != nil {
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), caller, metrics.CacheHitLabel).Inc()
} else {
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), caller, metrics.CacheMissLabel).Inc()
}
return cacheShardLeaders
}
func (m *MetaCache) updateShardLocationCache(ctx context.Context, database, collectionName string, collectionID int64) (*shardLeaders, error) {
log := log.Ctx(ctx).With(
zap.String("db", database),
zap.String("collectionName", collectionName),
zap.Int64("collectionID", collectionID))
method := "updateShardLocationCache"
tr := timerecord.NewTimeRecorder(method)
defer metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).
Observe(float64(tr.ElapseSpan().Milliseconds()))
req := &querypb.GetShardLeadersRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_GetShardLeaders),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
CollectionID: collectionID,
WithUnserviceableShards: true,
}
resp, err := m.mixCoord.GetShardLeaders(ctx, req)
if err := merr.CheckRPCCall(resp.GetStatus(), err); err != nil {
log.Error("failed to get shard locations",
zap.Int64("collectionID", collectionID),
zap.Error(err))
return nil, err
}
shards := parseShardLeaderList2QueryNode(resp.GetShards())
// convert shards map to string for logging
if log.Logger.Level() == zap.DebugLevel {
shardStr := make([]string, 0, len(shards))
for channel, nodes := range shards {
nodeStrs := make([]string, 0, len(nodes))
for _, node := range nodes {
nodeStrs = append(nodeStrs, node.String())
}
shardStr = append(shardStr, fmt.Sprintf("%s:[%s]", channel, strings.Join(nodeStrs, ", ")))
}
log.Debug("update shard leader cache", zap.String("newShardLeaders", strings.Join(shardStr, ", ")))
}
newShardLeaders := &shardLeaders{
collectionID: collectionID,
shardLeaders: shards,
idx: atomic.NewInt64(0),
}
m.leaderMut.Lock()
if _, ok := m.collLeader[database]; !ok {
m.collLeader[database] = make(map[string]*shardLeaders)
}
m.collLeader[database][collectionName] = newShardLeaders
m.leaderMut.Unlock()
return newShardLeaders, nil
}
func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) map[string][]nodeInfo {
shard2QueryNodes := make(map[string][]nodeInfo)
for _, leaders := range shardsLeaders {
qns := make([]nodeInfo, len(leaders.GetNodeIds()))
for j := range qns {
qns[j] = nodeInfo{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j], leaders.GetServiceable()[j]}
}
shard2QueryNodes[leaders.GetChannelName()] = qns
}
return shard2QueryNodes
}
// used for Garbage collection shard client
func (m *MetaCache) ListShardLocation() map[int64]nodeInfo {
m.leaderMut.RLock()
defer m.leaderMut.RUnlock()
shardLeaderInfo := make(map[int64]nodeInfo)
for _, dbInfo := range m.collLeader {
for _, shardLeaders := range dbInfo {
for _, nodeInfos := range shardLeaders.shardLeaders {
for _, node := range nodeInfos {
shardLeaderInfo[node.nodeID] = node
}
}
}
}
return shardLeaderInfo
}
// DeprecateShardCache clear the shard leader cache of a collection
func (m *MetaCache) DeprecateShardCache(database, collectionName string) {
log.Info("deprecate shard cache for collection", zap.String("collectionName", collectionName))
m.leaderMut.Lock()
defer m.leaderMut.Unlock()
dbInfo, ok := m.collLeader[database]
if ok {
delete(dbInfo, collectionName)
if len(dbInfo) == 0 {
delete(m.collLeader, database)
}
}
}
// InvalidateShardLeaderCache called when Shard leader balance happened
func (m *MetaCache) InvalidateShardLeaderCache(collections []int64) {
log.Info("Invalidate shard cache for collections", zap.Int64s("collectionIDs", collections))
m.leaderMut.Lock()
defer m.leaderMut.Unlock()
collectionSet := typeutil.NewUniqueSet(collections...)
for dbName, dbInfo := range m.collLeader {
for collectionName, shardLeaders := range dbInfo {
if collectionSet.Contain(shardLeaders.collectionID) {
delete(dbInfo, collectionName)
}
}
if len(dbInfo) == 0 {
delete(m.collLeader, dbName)
}
}
}
func (m *MetaCache) RemoveDatabase(ctx context.Context, database string) {
log.Ctx(ctx).Debug("remove database", zap.String("name", database))
m.mu.Lock()
defer m.mu.Unlock()
delete(m.collInfo, database)
delete(m.dbInfo, database)
m.mu.Unlock()
m.leaderMut.Lock()
delete(m.collLeader, database)
m.leaderMut.Unlock()
}
func (m *MetaCache) HasDatabase(ctx context.Context, database string) bool {


@ -27,8 +27,6 @@ import (
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
uatomic "go.uber.org/atomic"
"google.golang.org/grpc"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
@ -47,7 +45,6 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/crypto"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -966,8 +963,7 @@ func TestMetaCache_GetCollection(t *testing.T) {
ctx := context.Background()
rootCoord := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, mgr)
err := InitMetaCache(ctx, rootCoord)
assert.NoError(t, err)
id, err := globalMetaCache.GetCollectionID(ctx, dbName, "collection1")
@ -1019,8 +1015,7 @@ func TestMetaCache_GetBasicCollectionInfo(t *testing.T) {
ctx := context.Background()
rootCoord := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, mgr)
err := InitMetaCache(ctx, rootCoord)
assert.NoError(t, err)
// should be no data race.
@ -1053,8 +1048,7 @@ func TestMetaCacheGetCollectionWithUpdate(t *testing.T) {
ctx := context.Background()
rootCoord := mocks.NewMockMixCoordClient(t)
rootCoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{Status: merr.Success()}, nil)
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, mgr)
err := InitMetaCache(ctx, rootCoord)
assert.NoError(t, err)
t.Run("update with name", func(t *testing.T) {
rootCoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{
@ -1131,8 +1125,7 @@ func TestMetaCache_InitCache(t *testing.T) {
rootCoord := mocks.NewMockMixCoordClient(t)
rootCoord.EXPECT().ShowLoadCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
rootCoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(&internalpb.ListPolicyResponse{Status: merr.Success()}, nil).Once()
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, mgr)
err := InitMetaCache(ctx, rootCoord)
assert.NoError(t, err)
})
@ -1142,8 +1135,7 @@ func TestMetaCache_InitCache(t *testing.T) {
rootCoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(
&internalpb.ListPolicyResponse{Status: merr.Status(errors.New("mock list policy error"))},
nil).Once()
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, mgr)
err := InitMetaCache(ctx, rootCoord)
assert.Error(t, err)
})
@ -1152,8 +1144,7 @@ func TestMetaCache_InitCache(t *testing.T) {
rootCoord := mocks.NewMockMixCoordClient(t)
rootCoord.EXPECT().ListPolicy(mock.Anything, mock.Anything, mock.Anything).Return(
nil, errors.New("mock list policy rpc errorr")).Once()
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, mgr)
err := InitMetaCache(ctx, rootCoord)
assert.Error(t, err)
})
}
@ -1161,8 +1152,7 @@ func TestMetaCache_InitCache(t *testing.T) {
func TestMetaCache_GetCollectionName(t *testing.T) {
ctx := context.Background()
rootCoord := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, mgr)
err := InitMetaCache(ctx, rootCoord)
assert.NoError(t, err)
collection, err := globalMetaCache.GetCollectionName(ctx, GetCurDBNameFromContextOrDefault(ctx), 1)
@ -1213,8 +1203,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) {
func TestMetaCache_GetCollectionFailure(t *testing.T) {
ctx := context.Background()
rootCoord := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, mgr)
err := InitMetaCache(ctx, rootCoord)
assert.NoError(t, err)
rootCoord.Error = true
@ -1247,8 +1236,7 @@ func TestMetaCache_GetCollectionFailure(t *testing.T) {
func TestMetaCache_GetNonExistCollection(t *testing.T) {
ctx := context.Background()
rootCoord := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, mgr)
err := InitMetaCache(ctx, rootCoord)
assert.NoError(t, err)
id, err := globalMetaCache.GetCollectionID(ctx, dbName, "collection3")
@ -1262,8 +1250,7 @@ func TestMetaCache_GetNonExistCollection(t *testing.T) {
func TestMetaCache_GetPartitionID(t *testing.T) {
ctx := context.Background()
rootCoord := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, mgr)
err := InitMetaCache(ctx, rootCoord)
assert.NoError(t, err)
id, err := globalMetaCache.GetPartitionID(ctx, dbName, "collection1", "par1")
@ -1283,8 +1270,7 @@ func TestMetaCache_GetPartitionID(t *testing.T) {
func TestMetaCache_ConcurrentTest1(t *testing.T) {
ctx := context.Background()
rootCoord := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, mgr)
err := InitMetaCache(ctx, rootCoord)
assert.NoError(t, err)
var wg sync.WaitGroup
@ -1337,8 +1323,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) {
func TestMetaCache_GetPartitionError(t *testing.T) {
ctx := context.Background()
rootCoord := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, mgr)
err := InitMetaCache(ctx, rootCoord)
assert.NoError(t, err)
// Test the case where ShowPartitionsResponse is not aligned
@ -1362,133 +1347,23 @@ func TestMetaCache_GetPartitionError(t *testing.T) {
}
func TestMetaCache_GetShard(t *testing.T) {
var (
ctx = context.Background()
collectionName = "collection1"
collectionID = int64(1)
)
rootCoord := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, mgr)
require.Nil(t, err)
t.Run("No collection in meta cache", func(t *testing.T) {
shards, err := globalMetaCache.GetShard(ctx, true, dbName, "non-exists", 0, "channel-1")
assert.Error(t, err)
assert.Empty(t, shards)
})
t.Run("without shardLeaders in collection info invalid shardLeaders", func(t *testing.T) {
shards, err := globalMetaCache.GetShard(ctx, false, dbName, collectionName, collectionID, "channel-1")
assert.Error(t, err)
assert.Empty(t, shards)
})
t.Run("without shardLeaders in collection info", func(t *testing.T) {
rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: merr.Success(),
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil
}
shards, err := globalMetaCache.GetShard(ctx, true, dbName, collectionName, collectionID, "channel-1")
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 3, len(shards))
// get from cache
rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "not implemented",
},
}, nil
}
shards, err = globalMetaCache.GetShard(ctx, true, dbName, collectionName, collectionID, "channel-1")
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 3, len(shards))
})
t.Skip("GetShard has been moved to ShardClientMgr in shardclient package")
// Test body removed - functionality moved to shardclient package
}
func TestMetaCache_ClearShards(t *testing.T) {
var (
ctx = context.TODO()
collectionName = "collection1"
collectionID = int64(1)
)
qc := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, qc, mgr)
require.Nil(t, err)
t.Run("Clear with no collection info", func(t *testing.T) {
globalMetaCache.DeprecateShardCache(dbName, "collection_not_exist")
})
t.Run("Clear valid collection empty cache", func(t *testing.T) {
globalMetaCache.DeprecateShardCache(dbName, collectionName)
})
t.Run("Clear valid collection valid cache", func(t *testing.T) {
qc.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: merr.Success(),
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil
}
shards, err := globalMetaCache.GetShard(ctx, true, dbName, collectionName, collectionID, "channel-1")
require.NoError(t, err)
require.NotEmpty(t, shards)
require.Equal(t, 3, len(shards))
globalMetaCache.DeprecateShardCache(dbName, collectionName)
qc.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "not implemented",
},
}, nil
}
shards, err = globalMetaCache.GetShard(ctx, true, dbName, collectionName, collectionID, "channel-1")
assert.Error(t, err)
assert.Empty(t, shards)
})
t.Skip("DeprecateShardCache has been moved to ShardClientMgr in shardclient package")
// Test body removed - functionality moved to shardclient package
}
func TestMetaCache_PolicyInfo(t *testing.T) {
client := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
t.Run("InitMetaCache", func(t *testing.T) {
client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
return nil, errors.New("mock error")
}
err := InitMetaCache(context.Background(), client, mgr)
err := InitMetaCache(context.Background(), client)
assert.Error(t, err)
client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
@ -1497,7 +1372,7 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
PolicyInfos: []string{"policy1", "policy2", "policy3"},
}, nil
}
err = InitMetaCache(context.Background(), client, mgr)
err = InitMetaCache(context.Background(), client)
assert.NoError(t, err)
})
@ -1509,7 +1384,7 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
UserRoles: []string{funcutil.EncodeUserRoleCache("foo", "role1"), funcutil.EncodeUserRoleCache("foo", "role2"), funcutil.EncodeUserRoleCache("foo2", "role2")},
}, nil
}
err := InitMetaCache(context.Background(), client, mgr)
err := InitMetaCache(context.Background(), client)
assert.NoError(t, err)
policyInfos := privilege.GetPrivilegeCache().GetPrivilegeInfo(context.Background())
assert.Equal(t, 3, len(policyInfos))
@ -1525,7 +1400,7 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
UserRoles: []string{funcutil.EncodeUserRoleCache("foo", "role1"), funcutil.EncodeUserRoleCache("foo", "role2"), funcutil.EncodeUserRoleCache("foo2", "role2")},
}, nil
}
err := InitMetaCache(context.Background(), client, mgr)
err := InitMetaCache(context.Background(), client)
assert.NoError(t, err)
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheGrantPrivilege, OpKey: "policyX"})
@ -1566,7 +1441,7 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
UserRoles: []string{funcutil.EncodeUserRoleCache("foo", "role1"), funcutil.EncodeUserRoleCache("foo", "role2"), funcutil.EncodeUserRoleCache("foo2", "role2"), funcutil.EncodeUserRoleCache("foo2", "role3")},
}, nil
}
err := InitMetaCache(context.Background(), client, mgr)
err := InitMetaCache(context.Background(), client)
assert.NoError(t, err)
err = privilege.GetPrivilegeCache().RefreshPolicyInfo(typeutil.CacheOp{OpType: typeutil.CacheDeleteUser, OpKey: "foo"})
@ -1601,8 +1476,7 @@ func TestMetaCache_PolicyInfo(t *testing.T) {
func TestMetaCache_RemoveCollection(t *testing.T) {
ctx := context.Background()
rootCoord := &MockMixCoordClientInterface{}
shardMgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, shardMgr)
err := InitMetaCache(ctx, rootCoord)
assert.NoError(t, err)
rootCoord.showLoadCollections = func(ctx context.Context, in *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
@ -1646,48 +1520,14 @@ func TestMetaCache_RemoveCollection(t *testing.T) {
}
func TestGlobalMetaCache_ShuffleShardLeaders(t *testing.T) {
shards := map[string][]nodeInfo{
"channel-1": {
{
nodeID: 1,
address: "localhost:9000",
},
{
nodeID: 2,
address: "localhost:9000",
},
{
nodeID: 3,
address: "localhost:9000",
},
},
}
sl := &shardLeaders{
idx: uatomic.NewInt64(5),
shardLeaders: shards,
}
reader := sl.GetReader()
result := reader.Shuffle()
assert.Len(t, result["channel-1"], 3)
assert.Equal(t, int64(1), result["channel-1"][0].nodeID)
reader = sl.GetReader()
result = reader.Shuffle()
assert.Len(t, result["channel-1"], 3)
assert.Equal(t, int64(2), result["channel-1"][0].nodeID)
reader = sl.GetReader()
result = reader.Shuffle()
assert.Len(t, result["channel-1"], 3)
assert.Equal(t, int64(3), result["channel-1"][0].nodeID)
t.Skip("shardLeaders and nodeInfo have been moved to shardclient package")
// Test body removed - functionality moved to shardclient package
}
func TestMetaCache_Database(t *testing.T) {
ctx := context.Background()
rootCoord := &MockMixCoordClientInterface{}
shardMgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, shardMgr)
err := InitMetaCache(ctx, rootCoord)
assert.NoError(t, err)
assert.Equal(t, globalMetaCache.HasDatabase(ctx, dbName), false)
@ -1703,8 +1543,7 @@ func TestGetDatabaseInfo(t *testing.T) {
t.Run("success", func(t *testing.T) {
ctx := context.Background()
rootCoord := mocks.NewMockMixCoordClient(t)
shardMgr := newShardClientMgr()
cache, err := NewMetaCache(rootCoord, shardMgr)
cache, err := NewMetaCache(rootCoord)
assert.NoError(t, err)
rootCoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{
@ -1728,8 +1567,7 @@ func TestGetDatabaseInfo(t *testing.T) {
t.Run("error", func(t *testing.T) {
ctx := context.Background()
rootCoord := mocks.NewMockMixCoordClient(t)
shardMgr := newShardClientMgr()
cache, err := NewMetaCache(rootCoord, shardMgr)
cache, err := NewMetaCache(rootCoord)
assert.NoError(t, err)
rootCoord.EXPECT().DescribeDatabase(mock.Anything, mock.Anything).Return(&rootcoordpb.DescribeDatabaseResponse{
@ -1742,7 +1580,6 @@ func TestGetDatabaseInfo(t *testing.T) {
func TestMetaCache_AllocID(t *testing.T) {
ctx := context.Background()
shardMgr := newShardClientMgr()
t.Run("success", func(t *testing.T) {
rootCoord := mocks.NewMockMixCoordClient(t)
@ -1756,7 +1593,7 @@ func TestMetaCache_AllocID(t *testing.T) {
PolicyInfos: []string{"policy1", "policy2", "policy3"},
}, nil)
err := InitMetaCache(ctx, rootCoord, shardMgr)
err := InitMetaCache(ctx, rootCoord)
assert.NoError(t, err)
assert.Equal(t, globalMetaCache.HasDatabase(ctx, dbName), false)
@ -1775,7 +1612,7 @@ func TestMetaCache_AllocID(t *testing.T) {
PolicyInfos: []string{"policy1", "policy2", "policy3"},
}, nil)
err := InitMetaCache(ctx, rootCoord, shardMgr)
err := InitMetaCache(ctx, rootCoord)
assert.NoError(t, err)
assert.Equal(t, globalMetaCache.HasDatabase(ctx, dbName), false)
@ -1794,7 +1631,7 @@ func TestMetaCache_AllocID(t *testing.T) {
PolicyInfos: []string{"policy1", "policy2", "policy3"},
}, nil)
err := InitMetaCache(ctx, rootCoord, shardMgr)
err := InitMetaCache(ctx, rootCoord)
assert.NoError(t, err)
assert.Equal(t, globalMetaCache.HasDatabase(ctx, dbName), false)
@ -1805,52 +1642,8 @@ func TestMetaCache_AllocID(t *testing.T) {
}
func TestMetaCache_InvalidateShardLeaderCache(t *testing.T) {
paramtable.Init()
paramtable.Get().Save(Params.ProxyCfg.ShardLeaderCacheInterval.Key, "1")
ctx := context.Background()
rootCoord := &MockMixCoordClientInterface{}
shardMgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, shardMgr)
assert.NoError(t, err)
rootCoord.showLoadCollections = func(ctx context.Context, in *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
return &querypb.ShowCollectionsResponse{
Status: merr.Success(),
CollectionIDs: []UniqueID{1},
InMemoryPercentages: []int64{100},
}, nil
}
called := uatomic.NewInt32(0)
rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
called.Inc()
return &querypb.GetShardLeadersResponse{
Status: merr.Success(),
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
},
}, nil
}
nodeInfos, err := globalMetaCache.GetShard(ctx, true, dbName, "collection1", 1, "channel-1")
assert.NoError(t, err)
assert.Len(t, nodeInfos, 3)
assert.Equal(t, called.Load(), int32(1))
globalMetaCache.GetShard(ctx, true, dbName, "collection1", 1, "channel-1")
assert.Equal(t, called.Load(), int32(1))
globalMetaCache.InvalidateShardLeaderCache([]int64{1})
nodeInfos, err = globalMetaCache.GetShard(ctx, true, dbName, "collection1", 1, "channel-1")
assert.NoError(t, err)
assert.Len(t, nodeInfos, 3)
assert.Equal(t, called.Load(), int32(2))
t.Skip("GetShard and InvalidateShardLeaderCache have been moved to ShardClientMgr in shardclient package")
// Test body removed - functionality moved to shardclient package
}
func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) {
@ -2192,8 +1985,7 @@ func TestMetaCache_Parallel(t *testing.T) {
rootCoord.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&milvuspb.ShowPartitionsResponse{
Status: merr.Success(),
}, nil).Maybe()
mgr := newShardClientMgr()
cache, err := NewMetaCache(rootCoord, mgr)
cache, err := NewMetaCache(rootCoord)
assert.NoError(t, err)
cacheVersion := uint64(100)
@ -2242,84 +2034,6 @@ func TestMetaCache_Parallel(t *testing.T) {
}
func TestMetaCache_GetShardLeaderList(t *testing.T) {
var (
ctx = context.Background()
collectionName = "collection1"
collectionID = int64(1)
)
rootCoord := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
err := InitMetaCache(ctx, rootCoord, mgr)
require.Nil(t, err)
t.Run("No collection in meta cache", func(t *testing.T) {
channels, err := globalMetaCache.GetShardLeaderList(ctx, dbName, "non-exists", 0, true)
assert.Error(t, err)
assert.Empty(t, channels)
})
t.Run("Get channel list without cache", func(t *testing.T) {
rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: merr.Success(),
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
Serviceable: []bool{true, true, true},
},
{
ChannelName: "channel-2",
NodeIds: []int64{4, 5, 6},
NodeAddrs: []string{"localhost:9003", "localhost:9004", "localhost:9005"},
Serviceable: []bool{true, true, true},
},
},
}, nil
}
channels, err := globalMetaCache.GetShardLeaderList(ctx, dbName, collectionName, collectionID, false)
assert.NoError(t, err)
assert.Equal(t, 2, len(channels))
assert.Contains(t, channels, "channel-1")
assert.Contains(t, channels, "channel-2")
})
t.Run("Get channel list with cache", func(t *testing.T) {
// First call should populate cache
channels, err := globalMetaCache.GetShardLeaderList(ctx, dbName, collectionName, collectionID, true)
assert.NoError(t, err)
assert.Equal(t, 2, len(channels))
// Mock should return error but cache should be used
rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
return &querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "should not be called when using cache",
},
}, nil
}
channels, err = globalMetaCache.GetShardLeaderList(ctx, dbName, collectionName, collectionID, true)
assert.NoError(t, err)
assert.Equal(t, 2, len(channels))
assert.Contains(t, channels, "channel-1")
assert.Contains(t, channels, "channel-2")
})
t.Run("Error from coordinator", func(t *testing.T) {
// Deprecate cache first
globalMetaCache.DeprecateShardCache(dbName, collectionName)
rootCoord.getShardLeaders = func(ctx context.Context, in *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) {
return nil, errors.New("coordinator error")
}
channels, err := globalMetaCache.GetShardLeaderList(ctx, dbName, collectionName, collectionID, true)
assert.Error(t, err)
assert.Empty(t, channels)
})
t.Skip("GetShardLeaderList has been moved to ShardClientMgr in shardclient package")
// Test body removed - functionality moved to shardclient package
}

View File

@ -55,8 +55,7 @@ func InitEmptyGlobalCache() {
emptyMock := common.NewEmptyMockT()
mixcoord := mocks.NewMockMixCoordClient(emptyMock)
mixcoord.EXPECT().DescribeCollection(mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("collection not found"))
mgr := newShardClientMgr()
globalMetaCache, err = NewMetaCache(mixcoord, mgr)
globalMetaCache, err = NewMetaCache(mixcoord)
if err != nil {
panic(err)
}

View File

@ -5,10 +5,7 @@ package proxy
import (
context "context"
internalpb "github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
mock "github.com/stretchr/testify/mock"
typeutil "github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// MockCache is an autogenerated mock type for the Cache type
@ -80,40 +77,6 @@ func (_c *MockCache_AllocID_Call) RunAndReturn(run func(context.Context) (int64,
return _c
}
// DeprecateShardCache provides a mock function with given fields: database, collectionName
func (_m *MockCache) DeprecateShardCache(database string, collectionName string) {
_m.Called(database, collectionName)
}
// MockCache_DeprecateShardCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeprecateShardCache'
type MockCache_DeprecateShardCache_Call struct {
*mock.Call
}
// DeprecateShardCache is a helper method to define mock.On call
// - database string
// - collectionName string
func (_e *MockCache_Expecter) DeprecateShardCache(database interface{}, collectionName interface{}) *MockCache_DeprecateShardCache_Call {
return &MockCache_DeprecateShardCache_Call{Call: _e.mock.On("DeprecateShardCache", database, collectionName)}
}
func (_c *MockCache_DeprecateShardCache_Call) Run(run func(database string, collectionName string)) *MockCache_DeprecateShardCache_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string), args[1].(string))
})
return _c
}
func (_c *MockCache_DeprecateShardCache_Call) Return() *MockCache_DeprecateShardCache_Call {
_c.Call.Return()
return _c
}
func (_c *MockCache_DeprecateShardCache_Call) RunAndReturn(run func(string, string)) *MockCache_DeprecateShardCache_Call {
_c.Run(run)
return _c
}
// GetCollectionID provides a mock function with given fields: ctx, database, collectionName
func (_m *MockCache) GetCollectionID(ctx context.Context, database string, collectionName string) (int64, error) {
ret := _m.Called(ctx, database, collectionName)
@ -351,65 +314,6 @@ func (_c *MockCache_GetCollectionSchema_Call) RunAndReturn(run func(context.Cont
return _c
}
// GetCredentialInfo provides a mock function with given fields: ctx, username
func (_m *MockCache) GetCredentialInfo(ctx context.Context, username string) (*internalpb.CredentialInfo, error) {
ret := _m.Called(ctx, username)
if len(ret) == 0 {
panic("no return value specified for GetCredentialInfo")
}
var r0 *internalpb.CredentialInfo
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string) (*internalpb.CredentialInfo, error)); ok {
return rf(ctx, username)
}
if rf, ok := ret.Get(0).(func(context.Context, string) *internalpb.CredentialInfo); ok {
r0 = rf(ctx, username)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*internalpb.CredentialInfo)
}
}
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, username)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCache_GetCredentialInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCredentialInfo'
type MockCache_GetCredentialInfo_Call struct {
*mock.Call
}
// GetCredentialInfo is a helper method to define mock.On call
// - ctx context.Context
// - username string
func (_e *MockCache_Expecter) GetCredentialInfo(ctx interface{}, username interface{}) *MockCache_GetCredentialInfo_Call {
return &MockCache_GetCredentialInfo_Call{Call: _e.mock.On("GetCredentialInfo", ctx, username)}
}
func (_c *MockCache_GetCredentialInfo_Call) Run(run func(ctx context.Context, username string)) *MockCache_GetCredentialInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string))
})
return _c
}
func (_c *MockCache_GetCredentialInfo_Call) Return(_a0 *internalpb.CredentialInfo, _a1 error) *MockCache_GetCredentialInfo_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockCache_GetCredentialInfo_Call) RunAndReturn(run func(context.Context, string) (*internalpb.CredentialInfo, error)) *MockCache_GetCredentialInfo_Call {
_c.Call.Return(run)
return _c
}
// GetDatabaseInfo provides a mock function with given fields: ctx, database
func (_m *MockCache) GetDatabaseInfo(ctx context.Context, database string) (*databaseInfo, error) {
ret := _m.Called(ctx, database)
@ -709,227 +613,6 @@ func (_c *MockCache_GetPartitionsIndex_Call) RunAndReturn(run func(context.Conte
return _c
}
// GetPrivilegeInfo provides a mock function with given fields: ctx
func (_m *MockCache) GetPrivilegeInfo(ctx context.Context) []string {
ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for GetPrivilegeInfo")
}
var r0 []string
if rf, ok := ret.Get(0).(func(context.Context) []string); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
return r0
}
// MockCache_GetPrivilegeInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPrivilegeInfo'
type MockCache_GetPrivilegeInfo_Call struct {
*mock.Call
}
// GetPrivilegeInfo is a helper method to define mock.On call
// - ctx context.Context
func (_e *MockCache_Expecter) GetPrivilegeInfo(ctx interface{}) *MockCache_GetPrivilegeInfo_Call {
return &MockCache_GetPrivilegeInfo_Call{Call: _e.mock.On("GetPrivilegeInfo", ctx)}
}
func (_c *MockCache_GetPrivilegeInfo_Call) Run(run func(ctx context.Context)) *MockCache_GetPrivilegeInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context))
})
return _c
}
func (_c *MockCache_GetPrivilegeInfo_Call) Return(_a0 []string) *MockCache_GetPrivilegeInfo_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockCache_GetPrivilegeInfo_Call) RunAndReturn(run func(context.Context) []string) *MockCache_GetPrivilegeInfo_Call {
_c.Call.Return(run)
return _c
}
// GetShard provides a mock function with given fields: ctx, withCache, database, collectionName, collectionID, channel
func (_m *MockCache) GetShard(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64, channel string) ([]nodeInfo, error) {
ret := _m.Called(ctx, withCache, database, collectionName, collectionID, channel)
if len(ret) == 0 {
panic("no return value specified for GetShard")
}
var r0 []nodeInfo
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64, string) ([]nodeInfo, error)); ok {
return rf(ctx, withCache, database, collectionName, collectionID, channel)
}
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64, string) []nodeInfo); ok {
r0 = rf(ctx, withCache, database, collectionName, collectionID, channel)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]nodeInfo)
}
}
if rf, ok := ret.Get(1).(func(context.Context, bool, string, string, int64, string) error); ok {
r1 = rf(ctx, withCache, database, collectionName, collectionID, channel)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCache_GetShard_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetShard'
type MockCache_GetShard_Call struct {
*mock.Call
}
// GetShard is a helper method to define mock.On call
// - ctx context.Context
// - withCache bool
// - database string
// - collectionName string
// - collectionID int64
// - channel string
func (_e *MockCache_Expecter) GetShard(ctx interface{}, withCache interface{}, database interface{}, collectionName interface{}, collectionID interface{}, channel interface{}) *MockCache_GetShard_Call {
return &MockCache_GetShard_Call{Call: _e.mock.On("GetShard", ctx, withCache, database, collectionName, collectionID, channel)}
}
func (_c *MockCache_GetShard_Call) Run(run func(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64, channel string)) *MockCache_GetShard_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(bool), args[2].(string), args[3].(string), args[4].(int64), args[5].(string))
})
return _c
}
func (_c *MockCache_GetShard_Call) Return(_a0 []nodeInfo, _a1 error) *MockCache_GetShard_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockCache_GetShard_Call) RunAndReturn(run func(context.Context, bool, string, string, int64, string) ([]nodeInfo, error)) *MockCache_GetShard_Call {
_c.Call.Return(run)
return _c
}
// GetShardLeaderList provides a mock function with given fields: ctx, database, collectionName, collectionID, withCache
func (_m *MockCache) GetShardLeaderList(ctx context.Context, database string, collectionName string, collectionID int64, withCache bool) ([]string, error) {
ret := _m.Called(ctx, database, collectionName, collectionID, withCache)
if len(ret) == 0 {
panic("no return value specified for GetShardLeaderList")
}
var r0 []string
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, bool) ([]string, error)); ok {
return rf(ctx, database, collectionName, collectionID, withCache)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, bool) []string); ok {
r0 = rf(ctx, database, collectionName, collectionID, withCache)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, int64, bool) error); ok {
r1 = rf(ctx, database, collectionName, collectionID, withCache)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockCache_GetShardLeaderList_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetShardLeaderList'
type MockCache_GetShardLeaderList_Call struct {
*mock.Call
}
// GetShardLeaderList is a helper method to define mock.On call
// - ctx context.Context
// - database string
// - collectionName string
// - collectionID int64
// - withCache bool
func (_e *MockCache_Expecter) GetShardLeaderList(ctx interface{}, database interface{}, collectionName interface{}, collectionID interface{}, withCache interface{}) *MockCache_GetShardLeaderList_Call {
return &MockCache_GetShardLeaderList_Call{Call: _e.mock.On("GetShardLeaderList", ctx, database, collectionName, collectionID, withCache)}
}
func (_c *MockCache_GetShardLeaderList_Call) Run(run func(ctx context.Context, database string, collectionName string, collectionID int64, withCache bool)) *MockCache_GetShardLeaderList_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(int64), args[4].(bool))
})
return _c
}
func (_c *MockCache_GetShardLeaderList_Call) Return(_a0 []string, _a1 error) *MockCache_GetShardLeaderList_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockCache_GetShardLeaderList_Call) RunAndReturn(run func(context.Context, string, string, int64, bool) ([]string, error)) *MockCache_GetShardLeaderList_Call {
_c.Call.Return(run)
return _c
}
// GetUserRole provides a mock function with given fields: username
func (_m *MockCache) GetUserRole(username string) []string {
ret := _m.Called(username)
if len(ret) == 0 {
panic("no return value specified for GetUserRole")
}
var r0 []string
if rf, ok := ret.Get(0).(func(string) []string); ok {
r0 = rf(username)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
return r0
}
// MockCache_GetUserRole_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUserRole'
type MockCache_GetUserRole_Call struct {
*mock.Call
}
// GetUserRole is a helper method to define mock.On call
// - username string
func (_e *MockCache_Expecter) GetUserRole(username interface{}) *MockCache_GetUserRole_Call {
return &MockCache_GetUserRole_Call{Call: _e.mock.On("GetUserRole", username)}
}
func (_c *MockCache_GetUserRole_Call) Run(run func(username string)) *MockCache_GetUserRole_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MockCache_GetUserRole_Call) Return(_a0 []string) *MockCache_GetUserRole_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockCache_GetUserRole_Call) RunAndReturn(run func(string) []string) *MockCache_GetUserRole_Call {
_c.Call.Return(run)
return _c
}
// HasDatabase provides a mock function with given fields: ctx, database
func (_m *MockCache) HasDatabase(ctx context.Context, database string) bool {
ret := _m.Called(ctx, database)
@ -977,166 +660,6 @@ func (_c *MockCache_HasDatabase_Call) RunAndReturn(run func(context.Context, str
return _c
}
// InitPolicyInfo provides a mock function with given fields: info, userRoles
func (_m *MockCache) InitPolicyInfo(info []string, userRoles []string) {
_m.Called(info, userRoles)
}
// MockCache_InitPolicyInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InitPolicyInfo'
type MockCache_InitPolicyInfo_Call struct {
*mock.Call
}
// InitPolicyInfo is a helper method to define mock.On call
// - info []string
// - userRoles []string
func (_e *MockCache_Expecter) InitPolicyInfo(info interface{}, userRoles interface{}) *MockCache_InitPolicyInfo_Call {
return &MockCache_InitPolicyInfo_Call{Call: _e.mock.On("InitPolicyInfo", info, userRoles)}
}
func (_c *MockCache_InitPolicyInfo_Call) Run(run func(info []string, userRoles []string)) *MockCache_InitPolicyInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]string), args[1].([]string))
})
return _c
}
func (_c *MockCache_InitPolicyInfo_Call) Return() *MockCache_InitPolicyInfo_Call {
_c.Call.Return()
return _c
}
func (_c *MockCache_InitPolicyInfo_Call) RunAndReturn(run func([]string, []string)) *MockCache_InitPolicyInfo_Call {
_c.Run(run)
return _c
}
// InvalidateShardLeaderCache provides a mock function with given fields: collections
func (_m *MockCache) InvalidateShardLeaderCache(collections []int64) {
_m.Called(collections)
}
// MockCache_InvalidateShardLeaderCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateShardLeaderCache'
type MockCache_InvalidateShardLeaderCache_Call struct {
*mock.Call
}
// InvalidateShardLeaderCache is a helper method to define mock.On call
// - collections []int64
func (_e *MockCache_Expecter) InvalidateShardLeaderCache(collections interface{}) *MockCache_InvalidateShardLeaderCache_Call {
return &MockCache_InvalidateShardLeaderCache_Call{Call: _e.mock.On("InvalidateShardLeaderCache", collections)}
}
func (_c *MockCache_InvalidateShardLeaderCache_Call) Run(run func(collections []int64)) *MockCache_InvalidateShardLeaderCache_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]int64))
})
return _c
}
func (_c *MockCache_InvalidateShardLeaderCache_Call) Return() *MockCache_InvalidateShardLeaderCache_Call {
_c.Call.Return()
return _c
}
func (_c *MockCache_InvalidateShardLeaderCache_Call) RunAndReturn(run func([]int64)) *MockCache_InvalidateShardLeaderCache_Call {
_c.Run(run)
return _c
}
// ListShardLocation provides a mock function with no fields
func (_m *MockCache) ListShardLocation() map[int64]nodeInfo {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for ListShardLocation")
}
var r0 map[int64]nodeInfo
if rf, ok := ret.Get(0).(func() map[int64]nodeInfo); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[int64]nodeInfo)
}
}
return r0
}
// MockCache_ListShardLocation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListShardLocation'
type MockCache_ListShardLocation_Call struct {
*mock.Call
}
// ListShardLocation is a helper method to define mock.On call
func (_e *MockCache_Expecter) ListShardLocation() *MockCache_ListShardLocation_Call {
return &MockCache_ListShardLocation_Call{Call: _e.mock.On("ListShardLocation")}
}
func (_c *MockCache_ListShardLocation_Call) Run(run func()) *MockCache_ListShardLocation_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockCache_ListShardLocation_Call) Return(_a0 map[int64]nodeInfo) *MockCache_ListShardLocation_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockCache_ListShardLocation_Call) RunAndReturn(run func() map[int64]nodeInfo) *MockCache_ListShardLocation_Call {
_c.Call.Return(run)
return _c
}
// RefreshPolicyInfo provides a mock function with given fields: op
func (_m *MockCache) RefreshPolicyInfo(op typeutil.CacheOp) error {
ret := _m.Called(op)
if len(ret) == 0 {
panic("no return value specified for RefreshPolicyInfo")
}
var r0 error
if rf, ok := ret.Get(0).(func(typeutil.CacheOp) error); ok {
r0 = rf(op)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockCache_RefreshPolicyInfo_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RefreshPolicyInfo'
type MockCache_RefreshPolicyInfo_Call struct {
*mock.Call
}
// RefreshPolicyInfo is a helper method to define mock.On call
// - op typeutil.CacheOp
func (_e *MockCache_Expecter) RefreshPolicyInfo(op interface{}) *MockCache_RefreshPolicyInfo_Call {
return &MockCache_RefreshPolicyInfo_Call{Call: _e.mock.On("RefreshPolicyInfo", op)}
}
func (_c *MockCache_RefreshPolicyInfo_Call) Run(run func(op typeutil.CacheOp)) *MockCache_RefreshPolicyInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(typeutil.CacheOp))
})
return _c
}
func (_c *MockCache_RefreshPolicyInfo_Call) Return(_a0 error) *MockCache_RefreshPolicyInfo_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockCache_RefreshPolicyInfo_Call) RunAndReturn(run func(typeutil.CacheOp) error) *MockCache_RefreshPolicyInfo_Call {
_c.Call.Return(run)
return _c
}
// RemoveCollection provides a mock function with given fields: ctx, database, collectionName
func (_m *MockCache) RemoveCollection(ctx context.Context, database string, collectionName string) {
_m.Called(ctx, database, collectionName)
@ -1223,39 +746,6 @@ func (_c *MockCache_RemoveCollectionsByID_Call) RunAndReturn(run func(context.Co
return _c
}
// RemoveCredential provides a mock function with given fields: username
func (_m *MockCache) RemoveCredential(username string) {
_m.Called(username)
}
// MockCache_RemoveCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveCredential'
type MockCache_RemoveCredential_Call struct {
*mock.Call
}
// RemoveCredential is a helper method to define mock.On call
// - username string
func (_e *MockCache_Expecter) RemoveCredential(username interface{}) *MockCache_RemoveCredential_Call {
return &MockCache_RemoveCredential_Call{Call: _e.mock.On("RemoveCredential", username)}
}
func (_c *MockCache_RemoveCredential_Call) Run(run func(username string)) *MockCache_RemoveCredential_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MockCache_RemoveCredential_Call) Return() *MockCache_RemoveCredential_Call {
_c.Call.Return()
return _c
}
func (_c *MockCache_RemoveCredential_Call) RunAndReturn(run func(string)) *MockCache_RemoveCredential_Call {
_c.Run(run)
return _c
}
// RemoveDatabase provides a mock function with given fields: ctx, database
func (_m *MockCache) RemoveDatabase(ctx context.Context, database string) {
_m.Called(ctx, database)
@ -1290,39 +780,6 @@ func (_c *MockCache_RemoveDatabase_Call) RunAndReturn(run func(context.Context,
return _c
}
// UpdateCredential provides a mock function with given fields: credInfo
func (_m *MockCache) UpdateCredential(credInfo *internalpb.CredentialInfo) {
_m.Called(credInfo)
}
// MockCache_UpdateCredential_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCredential'
type MockCache_UpdateCredential_Call struct {
*mock.Call
}
// UpdateCredential is a helper method to define mock.On call
// - credInfo *internalpb.CredentialInfo
func (_e *MockCache_Expecter) UpdateCredential(credInfo interface{}) *MockCache_UpdateCredential_Call {
return &MockCache_UpdateCredential_Call{Call: _e.mock.On("UpdateCredential", credInfo)}
}
func (_c *MockCache_UpdateCredential_Call) Run(run func(credInfo *internalpb.CredentialInfo)) *MockCache_UpdateCredential_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(*internalpb.CredentialInfo))
})
return _c
}
func (_c *MockCache_UpdateCredential_Call) Return() *MockCache_UpdateCredential_Call {
_c.Call.Return()
return _c
}
func (_c *MockCache_UpdateCredential_Call) RunAndReturn(run func(*internalpb.CredentialInfo)) *MockCache_UpdateCredential_Call {
_c.Run(run)
return _c
}
// NewMockCache creates a new instance of MockCache. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockCache(t interface {

View File

@ -1,193 +0,0 @@
// Code generated by mockery v2.53.3. DO NOT EDIT.
package proxy
import (
context "context"
types "github.com/milvus-io/milvus/internal/types"
mock "github.com/stretchr/testify/mock"
)
// MockShardClientManager is an autogenerated mock type for the shardClientMgr type
type MockShardClientManager struct {
mock.Mock
}
type MockShardClientManager_Expecter struct {
mock *mock.Mock
}
func (_m *MockShardClientManager) EXPECT() *MockShardClientManager_Expecter {
return &MockShardClientManager_Expecter{mock: &_m.Mock}
}
// Close provides a mock function with no fields
func (_m *MockShardClientManager) Close() {
_m.Called()
}
// MockShardClientManager_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
type MockShardClientManager_Close_Call struct {
*mock.Call
}
// Close is a helper method to define mock.On call
func (_e *MockShardClientManager_Expecter) Close() *MockShardClientManager_Close_Call {
return &MockShardClientManager_Close_Call{Call: _e.mock.On("Close")}
}
func (_c *MockShardClientManager_Close_Call) Run(run func()) *MockShardClientManager_Close_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockShardClientManager_Close_Call) Return() *MockShardClientManager_Close_Call {
_c.Call.Return()
return _c
}
func (_c *MockShardClientManager_Close_Call) RunAndReturn(run func()) *MockShardClientManager_Close_Call {
_c.Run(run)
return _c
}
// GetClient provides a mock function with given fields: ctx, nodeInfo1
func (_m *MockShardClientManager) GetClient(ctx context.Context, nodeInfo1 nodeInfo) (types.QueryNodeClient, error) {
ret := _m.Called(ctx, nodeInfo1)
if len(ret) == 0 {
panic("no return value specified for GetClient")
}
var r0 types.QueryNodeClient
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, nodeInfo) (types.QueryNodeClient, error)); ok {
return rf(ctx, nodeInfo1)
}
if rf, ok := ret.Get(0).(func(context.Context, nodeInfo) types.QueryNodeClient); ok {
r0 = rf(ctx, nodeInfo1)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(types.QueryNodeClient)
}
}
if rf, ok := ret.Get(1).(func(context.Context, nodeInfo) error); ok {
r1 = rf(ctx, nodeInfo1)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockShardClientManager_GetClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetClient'
type MockShardClientManager_GetClient_Call struct {
*mock.Call
}
// GetClient is a helper method to define mock.On call
// - ctx context.Context
// - nodeInfo1 nodeInfo
func (_e *MockShardClientManager_Expecter) GetClient(ctx interface{}, nodeInfo1 interface{}) *MockShardClientManager_GetClient_Call {
return &MockShardClientManager_GetClient_Call{Call: _e.mock.On("GetClient", ctx, nodeInfo1)}
}
func (_c *MockShardClientManager_GetClient_Call) Run(run func(ctx context.Context, nodeInfo1 nodeInfo)) *MockShardClientManager_GetClient_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(nodeInfo))
})
return _c
}
func (_c *MockShardClientManager_GetClient_Call) Return(_a0 types.QueryNodeClient, _a1 error) *MockShardClientManager_GetClient_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockShardClientManager_GetClient_Call) RunAndReturn(run func(context.Context, nodeInfo) (types.QueryNodeClient, error)) *MockShardClientManager_GetClient_Call {
_c.Call.Return(run)
return _c
}
// SetClientCreatorFunc provides a mock function with given fields: creator
func (_m *MockShardClientManager) SetClientCreatorFunc(creator queryNodeCreatorFunc) {
_m.Called(creator)
}
// MockShardClientManager_SetClientCreatorFunc_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetClientCreatorFunc'
type MockShardClientManager_SetClientCreatorFunc_Call struct {
*mock.Call
}
// SetClientCreatorFunc is a helper method to define mock.On call
// - creator queryNodeCreatorFunc
func (_e *MockShardClientManager_Expecter) SetClientCreatorFunc(creator interface{}) *MockShardClientManager_SetClientCreatorFunc_Call {
return &MockShardClientManager_SetClientCreatorFunc_Call{Call: _e.mock.On("SetClientCreatorFunc", creator)}
}
func (_c *MockShardClientManager_SetClientCreatorFunc_Call) Run(run func(creator queryNodeCreatorFunc)) *MockShardClientManager_SetClientCreatorFunc_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(queryNodeCreatorFunc))
})
return _c
}
func (_c *MockShardClientManager_SetClientCreatorFunc_Call) Return() *MockShardClientManager_SetClientCreatorFunc_Call {
_c.Call.Return()
return _c
}
func (_c *MockShardClientManager_SetClientCreatorFunc_Call) RunAndReturn(run func(queryNodeCreatorFunc)) *MockShardClientManager_SetClientCreatorFunc_Call {
_c.Run(run)
return _c
}
// Start provides a mock function with no fields
func (_m *MockShardClientManager) Start() {
_m.Called()
}
// MockShardClientManager_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start'
type MockShardClientManager_Start_Call struct {
*mock.Call
}
// Start is a helper method to define mock.On call
func (_e *MockShardClientManager_Expecter) Start() *MockShardClientManager_Start_Call {
return &MockShardClientManager_Start_Call{Call: _e.mock.On("Start")}
}
func (_c *MockShardClientManager_Start_Call) Run(run func()) *MockShardClientManager_Start_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockShardClientManager_Start_Call) Return() *MockShardClientManager_Start_Call {
_c.Call.Return()
return _c
}
func (_c *MockShardClientManager_Start_Call) RunAndReturn(run func()) *MockShardClientManager_Start_Call {
_c.Run(run)
return _c
}
// NewMockShardClientManager creates a new instance of MockShardClientManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockShardClientManager(t interface {
mock.TestingT
Cleanup(func())
}) *MockShardClientManager {
mock := &MockShardClientManager{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -119,7 +119,7 @@ func TestRepackInsertDataWithPartitionKey(t *testing.T) {
mix := NewMixCoordMock()
err := InitMetaCache(ctx, mix, nil)
err := InitMetaCache(ctx, mix)
assert.NoError(t, err)
idAllocator, err := allocator.NewIDAllocator(ctx, mix, paramtable.GetNodeID())

View File

@ -49,7 +49,6 @@ func TestPrivilegeInterceptor(t *testing.T) {
ctx = GetContext(context.Background(), "alice:123456")
client := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
return &internalpb.ListPolicyResponse{
@ -80,7 +79,7 @@ func TestPrivilegeInterceptor(t *testing.T) {
})
assert.NoError(t, err)
err = InitMetaCache(ctx, client, mgr)
err = InitMetaCache(ctx, client)
assert.NoError(t, err)
_, err = PrivilegeInterceptor(ctx, &milvuspb.HasCollectionRequest{
DbName: "db_test",
@ -218,7 +217,6 @@ func TestResourceGroupPrivilege(t *testing.T) {
ctx = GetContext(context.Background(), "fooo:123456")
client := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
return &internalpb.ListPolicyResponse{
@ -236,7 +234,7 @@ func TestResourceGroupPrivilege(t *testing.T) {
},
}, nil
}
InitMetaCache(ctx, client, mgr)
InitMetaCache(ctx, client)
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.CreateResourceGroupRequest{
ResourceGroup: "rg",
@ -274,7 +272,6 @@ func TestPrivilegeGroup(t *testing.T) {
var err error
ctx = GetContext(context.Background(), "fooo:123456")
client := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
return &internalpb.ListPolicyResponse{
@ -287,7 +284,7 @@ func TestPrivilegeGroup(t *testing.T) {
},
}, nil
}
InitMetaCache(ctx, client, mgr)
InitMetaCache(ctx, client)
defer privilege.CleanPrivilegeCache()
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
@ -331,7 +328,6 @@ func TestPrivilegeGroup(t *testing.T) {
var err error
ctx = GetContext(context.Background(), "fooo:123456")
client := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
return &internalpb.ListPolicyResponse{
@ -344,7 +340,7 @@ func TestPrivilegeGroup(t *testing.T) {
},
}, nil
}
InitMetaCache(ctx, client, mgr)
InitMetaCache(ctx, client)
defer privilege.CleanPrivilegeCache()
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
@ -388,7 +384,6 @@ func TestPrivilegeGroup(t *testing.T) {
var err error
ctx = GetContext(context.Background(), "fooo:123456")
client := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
return &internalpb.ListPolicyResponse{
@ -401,7 +396,7 @@ func TestPrivilegeGroup(t *testing.T) {
},
}, nil
}
InitMetaCache(ctx, client, mgr)
InitMetaCache(ctx, client)
defer privilege.CleanPrivilegeCache()
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
@ -491,7 +486,6 @@ func TestPrivilegeGroup(t *testing.T) {
var err error
ctx = GetContext(context.Background(), "fooo:123456")
client := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
return &internalpb.ListPolicyResponse{
@ -504,7 +498,7 @@ func TestPrivilegeGroup(t *testing.T) {
},
}, nil
}
InitMetaCache(ctx, client, mgr)
InitMetaCache(ctx, client)
defer privilege.CleanPrivilegeCache()
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{
@ -558,7 +552,6 @@ func TestPrivilegeGroup(t *testing.T) {
var err error
ctx = GetContext(context.Background(), "fooo:123456")
client := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
client.listPolicy = func(ctx context.Context, in *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) {
return &internalpb.ListPolicyResponse{
@ -571,7 +564,7 @@ func TestPrivilegeGroup(t *testing.T) {
},
}, nil
}
InitMetaCache(ctx, client, mgr)
InitMetaCache(ctx, client)
defer privilege.CleanPrivilegeCache()
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.QueryRequest{})
@ -599,7 +592,6 @@ func TestBuiltinPrivilegeGroup(t *testing.T) {
var err error
ctx := GetContext(context.Background(), "fooo:123456")
client := &MockMixCoordClientInterface{}
mgr := newShardClientMgr()
policies := []string{}
for _, priv := range Params.RbacConfig.GetDefaultPrivilegeGroup("ClusterReadOnly").Privileges {
@ -615,7 +607,7 @@ func TestBuiltinPrivilegeGroup(t *testing.T) {
},
}, nil
}
InitMetaCache(ctx, client, mgr)
InitMetaCache(ctx, client)
defer privilege.CleanPrivilegeCache()
_, err = PrivilegeInterceptor(GetContext(context.Background(), "fooo:123456"), &milvuspb.SelectUserRequest{})

View File

@ -34,6 +34,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/proxy/connection"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/hookutil"
@ -96,7 +97,7 @@ type Proxy struct {
metricsCacheManager *metricsinfo.MetricsCacheManager
session *sessionutil.Session
shardMgr shardClientMgr
shardMgr shardclient.ShardClientMgr
searchResultCh chan *internalpb.SearchResults
@ -105,7 +106,7 @@ type Proxy struct {
closeCallbacks []func()
// for load balance in replicas
lbPolicy LBPolicy
lbPolicy shardclient.LBPolicy
// resource manager
resourceManager resource.Manager
@ -124,17 +125,14 @@ func NewProxy(ctx context.Context, _ dependency.Factory) (*Proxy, error) {
rand.Seed(time.Now().UnixNano())
ctx1, cancel := context.WithCancel(ctx)
n := 1024 // better to be configurable
mgr := newShardClientMgr()
lbPolicy := NewLBPolicyImpl(mgr)
lbPolicy.Start(ctx)
resourceManager := resource.NewManager(10*time.Second, 20*time.Second, make(map[string]time.Duration))
node := &Proxy{
ctx: ctx1,
cancel: cancel,
searchResultCh: make(chan *internalpb.SearchResults, n),
shardMgr: mgr,
simpleLimiter: NewSimpleLimiter(Params.QuotaConfig.AllocWaitInterval.GetAsDuration(time.Millisecond), Params.QuotaConfig.AllocRetryTimes.GetAsUint()),
lbPolicy: lbPolicy,
ctx: ctx1,
cancel: cancel,
searchResultCh: make(chan *internalpb.SearchResults, n),
// shardMgr: mgr,
simpleLimiter: NewSimpleLimiter(Params.QuotaConfig.AllocWaitInterval.GetAsDuration(time.Millisecond), Params.QuotaConfig.AllocRetryTimes.GetAsUint()),
// lbPolicy: lbPolicy,
resourceManager: resourceManager,
slowQueries: expirable.NewLRU[Timestamp, *metricsinfo.SlowQuery](20, nil, time.Minute*15),
}
@ -247,12 +245,15 @@ func (node *Proxy) Init() error {
node.metricsCacheManager = metricsinfo.NewMetricsCacheManager()
log.Debug("create metrics cache manager done", zap.String("role", typeutil.ProxyRole))
if err := InitMetaCache(node.ctx, node.mixCoord, node.shardMgr); err != nil {
if err := InitMetaCache(node.ctx, node.mixCoord); err != nil {
log.Warn("failed to init meta cache", zap.String("role", typeutil.ProxyRole), zap.Error(err))
return err
}
log.Debug("init meta cache done", zap.String("role", typeutil.ProxyRole))
node.shardMgr = shardclient.NewShardClientMgr(node.mixCoord)
node.lbPolicy = shardclient.NewLBPolicyImpl(node.shardMgr)
node.enableMaterializedView = Params.CommonCfg.EnableMaterializedView.GetAsBool()
// Enable internal rand pool for UUIDv4 generation
@ -273,6 +274,8 @@ func (node *Proxy) Start() error {
node.shardMgr.Start()
log.Debug("start shard client manager done", zap.String("role", typeutil.ProxyRole))
node.lbPolicy.Start(node.ctx)
if err := node.sched.Start(); err != nil {
log.Warn("failed to start task scheduler", zap.String("role", typeutil.ProxyRole), zap.Error(err))
return err

View File

@ -52,6 +52,7 @@ import (
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/util/componentutil"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/sessionutil"
@ -74,6 +75,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/util/metric"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
_ "github.com/milvus-io/milvus/pkg/v2/util/symbolizer" // support symbolizer and crash dump
"github.com/milvus-io/milvus/pkg/v2/util/testutils"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -1037,7 +1039,11 @@ func TestProxy(t *testing.T) {
proxy.SetMixCoordClient(rootCoordClient)
log.Info("Proxy set mix coordinator client")
proxy.SetQueryNodeCreator(defaultQueryNodeClientCreator)
mockShardMgr := shardclient.NewMockShardClientManager(t)
mockShardMgr.EXPECT().SetClientCreatorFunc(mock.Anything).Return().Maybe()
proxy.shardMgr = mockShardMgr
proxy.SetQueryNodeCreator(shardclient.DefaultQueryNodeClientCreator)
log.Info("Proxy set query coordinator client")
proxy.UpdateStateCode(commonpb.StateCode_Initializing)

View File

@ -404,13 +404,14 @@ func (op *requeryOperator) requery(ctx context.Context, span trace.Span, ids *sc
PartitionIDs: op.partitionIDs, // use search partitionIDs
ConsistencyLevel: op.consistencyLevel,
},
request: queryReq,
plan: plan,
mixCoord: op.node.(*Proxy).mixCoord,
lb: op.node.(*Proxy).lbPolicy,
channelsMvcc: channelsMvcc,
fastSkip: true,
reQuery: true,
request: queryReq,
plan: plan,
mixCoord: op.node.(*Proxy).mixCoord,
lb: op.node.(*Proxy).lbPolicy,
shardclientMgr: op.node.(*Proxy).shardMgr,
channelsMvcc: channelsMvcc,
fastSkip: true,
reQuery: true,
}
queryResult, storageCost, err := op.node.(*Proxy).query(op.traceCtx, qt, span)
if err != nil {

View File

@ -25,7 +25,7 @@ func TestNewInterceptor(t *testing.T) {
mixCoord := mocks.NewMockMixCoordClient(t)
mixCoord.On("DescribeCollection", mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe()
var err error
globalMetaCache, err = NewMetaCache(mixCoord, nil)
globalMetaCache, err = NewMetaCache(mixCoord)
assert.NoError(t, err)
interceptor, err := NewInterceptor[*milvuspb.DescribeCollectionRequest, *milvuspb.DescribeCollectionResponse](node, "DescribeCollection")
assert.NoError(t, err)

View File

@ -1,160 +0,0 @@
package proxy
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/atomic"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func TestShardClientMgr(t *testing.T) {
ctx := context.Background()
nodeInfo := nodeInfo{
nodeID: 1,
}
qn := mocks.NewMockQueryNodeClient(t)
qn.EXPECT().Close().Return(nil)
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
return qn, nil
}
mgr := newShardClientMgr()
mgr.SetClientCreatorFunc(creator)
_, err := mgr.GetClient(ctx, nodeInfo)
assert.Nil(t, err)
mgr.Close()
assert.Equal(t, mgr.clients.Len(), 0)
}
func TestShardClient(t *testing.T) {
nodeInfo := nodeInfo{
nodeID: 1,
}
qn := mocks.NewMockQueryNodeClient(t)
qn.EXPECT().Close().Return(nil).Maybe()
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
return qn, nil
}
shardClient := newShardClient(nodeInfo, creator, 3*time.Second)
assert.Equal(t, len(shardClient.clients), 0)
assert.Equal(t, false, shardClient.initialized.Load())
assert.Equal(t, false, shardClient.isClosed)
ctx := context.Background()
_, err := shardClient.getClient(ctx)
assert.Nil(t, err)
assert.Equal(t, len(shardClient.clients), paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt())
// test close
closed := shardClient.Close(false)
assert.False(t, closed)
closed = shardClient.Close(true)
assert.True(t, closed)
}
func TestPurgeClient(t *testing.T) {
node := nodeInfo{
nodeID: 1,
}
returnEmptyResult := atomic.NewBool(false)
cache := NewMockCache(t)
cache.EXPECT().ListShardLocation().RunAndReturn(func() map[int64]nodeInfo {
if returnEmptyResult.Load() {
return map[int64]nodeInfo{}
}
return map[int64]nodeInfo{
1: node,
}
})
globalMetaCache = cache
qn := mocks.NewMockQueryNodeClient(t)
qn.EXPECT().Close().Return(nil).Maybe()
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
return qn, nil
}
s := &shardClientMgrImpl{
clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](),
clientCreator: creator,
closeCh: make(chan struct{}),
purgeInterval: 1 * time.Second,
expiredDuration: 3 * time.Second,
}
go s.PurgeClient()
defer s.Close()
_, err := s.GetClient(context.Background(), node)
assert.Nil(t, err)
qnClient, ok := s.clients.Get(1)
assert.True(t, ok)
assert.True(t, qnClient.lastActiveTs.Load() > 0)
time.Sleep(2 * time.Second)
// expected client should not been purged before expiredDuration
assert.Equal(t, s.clients.Len(), 1)
assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() >= 2*time.Second.Nanoseconds())
_, err = s.GetClient(context.Background(), node)
assert.Nil(t, err)
time.Sleep(2 * time.Second)
// GetClient should refresh lastActiveTs, expected client should not be purged
assert.Equal(t, s.clients.Len(), 1)
assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() < 3*time.Second.Nanoseconds())
time.Sleep(2 * time.Second)
// client reach the expiredDuration, expected client should not be purged
assert.Equal(t, s.clients.Len(), 1)
assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() > 3*time.Second.Nanoseconds())
returnEmptyResult.Store(true)
time.Sleep(2 * time.Second)
// remove client from shard location, expected client should be purged
assert.Equal(t, s.clients.Len(), 0)
}
func BenchmarkShardClientMgr(b *testing.B) {
node := nodeInfo{
nodeID: 1,
}
cache := NewMockCache(b)
cache.EXPECT().ListShardLocation().Return(map[int64]nodeInfo{
1: node,
}).Maybe()
globalMetaCache = cache
qn := mocks.NewMockQueryNodeClient(b)
qn.EXPECT().Close().Return(nil).Maybe()
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
return qn, nil
}
s := &shardClientMgrImpl{
clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](),
clientCreator: creator,
closeCh: make(chan struct{}),
purgeInterval: 1 * time.Second,
expiredDuration: 10 * time.Second,
}
go s.PurgeClient()
defer s.Close()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := s.GetClient(context.Background(), node)
assert.Nil(b, err)
}
})
}

View File

@ -0,0 +1,8 @@
# order by contributions
reviewers:
- weiliu1031
- congqixia
- czs007
approvers:
- maintainers

View File

@ -0,0 +1,351 @@
# ShardClient Package
The `shardclient` package provides client-side connection management and load balancing for communicating with QueryNode shards in the Milvus distributed architecture. It maintains a pool of QueryNode client connections, caches shard leader information, and routes requests through configurable load-balancing policies.
## Overview
In Milvus, collections are divided into shards (channels), and each shard has multiple replicas distributed across different QueryNodes for high availability and load balancing. The `shardclient` package is responsible for:
1. **Connection Management**: Maintaining a pool of gRPC connections to QueryNodes with automatic lifecycle management
2. **Shard Leader Cache**: Caching the mapping of shards to their leader QueryNodes to reduce coordination overhead
3. **Load Balancing**: Distributing requests across available QueryNode replicas using configurable policies
4. **Fault Tolerance**: Automatic retry and failover when QueryNodes become unavailable
## Architecture
```
┌──────────────────────────────────────────────────────────────┐
│ Proxy Layer │
│ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ ShardClientMgr │ │
│ │ • Shard leader cache (database → collection → shards) │
│ │ • QueryNode client pool management │
│ │ • Client lifecycle (init, purge, close) │
│ └───────────────────────┬──────────────────────────────┘ │
│ │ │
│ ┌───────────────────────▼──────────────────────────────┐ │
│ │ LBPolicy │ │
│ │ • Execute workload on collection/channels │ │
│ │ • Retry logic with replica failover │ │
│ │ • Node selection via balancer │ │
│ └───────────────────────┬──────────────────────────────┘ │
│ │ │
│ ┌────────────────┴────────────────┐ │
│ │ │ │
│ ┌──────▼────────┐ ┌─────────▼──────────┐ │
│ │ RoundRobin │ │ LookAsideBalancer │ │
│ │ Balancer │ │ • Cost-based │ │
│ │ │ │ • Health check │ │
│ └───────────────┘ └────────────────────┘ │
│ │ │
│ ┌───────────────────────▼──────────────────────────────┐ │
│ │ shardClient (per QueryNode) │ │
│ │ • Connection pool (configurable size) │ │
│ │ • Round-robin client selection │ │
│ │ • Lazy initialization and expiration │ │
│ └──────────────────────────────────────────────────────┘ │
└─────────────────────┬────────────────────────────────────────┘
│ gRPC
┌───────────────┴───────────────┐
│ │
┌─────▼─────┐ ┌──────▼──────┐
│ QueryNode │ │ QueryNode │
│ (1) │ │ (2) │
└───────────┘ └─────────────┘
```
## Core Components
### 1. ShardClientMgr
The central manager for QueryNode client connections and shard leader information.
**File**: `manager.go`
**Key Responsibilities**:
- Cache shard leader mappings from QueryCoord (`database → collectionName → channel → []NodeInfo`)
- Manage `shardClient` instances for each QueryNode
- Automatically purge expired clients (default: 60 minutes of inactivity)
- Invalidate cache when shard leaders change
**Interface**:
```go
type ShardClientMgr interface {
GetShard(ctx context.Context, withCache bool, database, collectionName string,
collectionID int64, channel string) ([]NodeInfo, error)
GetShardLeaderList(ctx context.Context, database, collectionName string,
collectionID int64, withCache bool) ([]string, error)
DeprecateShardCache(database, collectionName string)
InvalidateShardLeaderCache(collections []int64)
GetClient(ctx context.Context, nodeInfo NodeInfo) (types.QueryNodeClient, error)
Start()
Close()
}
```
**Configuration**:
- `purgeInterval`: Interval for checking expired clients (default: 600s)
- `expiredDuration`: Time after which inactive clients are purged (default: 60min)
### 2. shardClient
Manages a connection pool to a single QueryNode.
**File**: `shard_client.go`
**Features**:
- **Lazy initialization**: Connections are created on first use
- **Connection pooling**: Configurable pool size (`ProxyCfg.QueryNodePoolingSize`, default: 1)
- **Round-robin selection**: Distributes requests across pool connections
- **Expiration tracking**: Tracks last active time for automatic cleanup
- **Thread-safe**: Safe for concurrent access
**Lifecycle**:
1. Created when first request needs a QueryNode
2. Initializes connection pool on first `getClient()` call
3. Tracks `lastActiveTs` on each use
4. Closed by manager if expired or during shutdown
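The pooling behavior above can be sketched as follows (a minimal illustration, not the actual `shardClient` code; names such as `pooledClient` and `creator` are assumptions):
```go
import (
    "context"
    "sync"
    "time"

    "go.uber.org/atomic"

    "github.com/milvus-io/milvus/internal/types"
)

// Sketch of a per-QueryNode client pool: lazy initialization on first use,
// round-robin selection across the pool, and last-active tracking for purging.
type pooledClient struct {
    mu           sync.Mutex
    initialized  bool
    clients      []types.QueryNodeClient // pool of size ProxyCfg.QueryNodePoolingSize
    idx          atomic.Int64            // round-robin cursor
    lastActiveTs atomic.Int64            // unix nanos of the last use
    creator      func(ctx context.Context) (types.QueryNodeClient, error)
}

func (p *pooledClient) getClient(ctx context.Context, poolSize int) (types.QueryNodeClient, error) {
    p.lastActiveTs.Store(time.Now().UnixNano())

    p.mu.Lock()
    if !p.initialized { // lazy initialization: dial only when first needed
        if poolSize <= 0 {
            poolSize = 1
        }
        for i := 0; i < poolSize; i++ {
            c, err := p.creator(ctx)
            if err != nil {
                p.mu.Unlock()
                return nil, err
            }
            p.clients = append(p.clients, c)
        }
        p.initialized = true
    }
    p.mu.Unlock()

    // round-robin over the pooled connections
    return p.clients[int(p.idx.Inc())%len(p.clients)], nil
}
```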
### 3. LBPolicy
Executes workloads on collections/channels with retry and failover logic.
**File**: `lb_policy.go`
**Key Methods**:
- **`Execute(ctx, CollectionWorkLoad)`**: Execute workload in parallel across all shards
- **`ExecuteOneChannel(ctx, CollectionWorkLoad)`**: Execute workload on any single shard (for lightweight operations)
- **`ExecuteWithRetry(ctx, ChannelWorkload)`**: Execute on specific channel with retry on different replicas
**Retry Strategy**:
- Retry up to `max(retryOnReplica, len(shardLeaders))` times
- Maintain `excludeNodes` set to avoid retrying failed nodes
- Refresh shard leader cache if initial attempt fails
- Clear `excludeNodes` if all replicas exhausted
**Workload Types**:
```go
type ChannelWorkload struct {
Db string
CollectionName string
CollectionID int64
Channel string
Nq int64 // Number of queries
Exec ExecuteFunc // Actual work to execute
}
type ExecuteFunc func(context.Context, UniqueID, types.QueryNodeClient, string) error
```
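Putting the retry strategy and workload types together, `ExecuteWithRetry` behaves roughly like the sketch below (the surrounding names `selectNode`, `clientMgr`, `balancer`, `retryOnReplica`, and `shardLeaders` are assumed from context; only the retry/exclusion pattern is the point):
```go
// Retry on different replicas: exclude nodes that failed, refresh the shard
// leader cache inside selectNode when needed, and stop on context cancellation.
excludeNodes := typeutil.NewUniqueSet()
retryTimes := retryOnReplica
if len(shardLeaders) > retryTimes {
    retryTimes = len(shardLeaders)
}

err := retry.Handle(ctx, func() (bool, error) {
    target, err := selectNode(ctx, balancer, workload, &excludeNodes)
    if err != nil {
        return true, err // selectNode already refreshed the shard leader cache
    }
    // release the workload reserved on the chosen node when done
    defer balancer.CancelWorkload(target.NodeID, workload.Nq)

    client, err := clientMgr.GetClient(ctx, target)
    if err != nil {
        excludeNodes.Insert(target.NodeID) // don't pick this node again
        return true, err
    }
    if err := workload.Exec(ctx, target.NodeID, client, workload.Channel); err != nil {
        excludeNodes.Insert(target.NodeID)
        return true, err
    }
    return true, nil
}, retry.Attempts(uint(retryTimes)))
// err holds the final outcome after all attempts
```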
### 4. Load Balancers
Two strategies for selecting QueryNode replicas:
#### RoundRobinBalancer
**File**: `roundrobin_balancer.go`
Simple round-robin selection across available nodes. No state tracking, minimal overhead.
**Use case**: Uniform workload distribution when all nodes have similar capacity
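A minimal sketch of this strategy, matching the `LBBalancer.SelectNode` signature (assumes `context`, `go.uber.org/atomic`, and the package's `merr` errors):
```go
// Pick the next node in turn, ignoring node load entirely.
type roundRobin struct {
    idx atomic.Int64
}

func (r *roundRobin) SelectNode(_ context.Context, availableNodes []int64, _ int64) (int64, error) {
    if len(availableNodes) == 0 {
        return -1, merr.ErrNodeNotAvailable
    }
    return availableNodes[int(r.idx.Inc())%len(availableNodes)], nil
}
```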
#### LookAsideBalancer
**File**: `look_aside_balancer.go`
Cost-aware load balancer that considers QueryNode workload and health.
**Features**:
- **Cost metrics tracking**: Caches `CostAggregation` (response time, service time, total NQ) from QueryNodes
- **Workload score calculation**: Uses power-of-3 formula to prefer lightly loaded nodes:
```
score = executeSpeed + (1 + totalNQ + executingNQ)³ × serviceTime
```
- **Periodic health checks**: Monitors QueryNode health via `GetComponentStates` RPC
- **Unavailable node handling**: Marks nodes unreachable after consecutive health check failures
- **Adaptive behavior**: Falls back to round-robin when workload difference is small
**Configuration Parameters**:
- `ProxyCfg.CostMetricsExpireTime`: How long to trust cached cost metrics (default: varies)
- `ProxyCfg.CheckWorkloadRequestNum`: Check workload every N requests (default: varies)
- `ProxyCfg.WorkloadToleranceFactor`: Tolerance for workload difference before preferring lighter node
- `ProxyCfg.CheckQueryNodeHealthInterval`: Interval for health checks
- `ProxyCfg.HealthCheckTimeout`: Timeout for health check RPC
- `ProxyCfg.RetryTimesOnHealthCheck`: Failures before marking node unreachable
**Selection Strategy**:
```
if (requestCount % CheckWorkloadRequestNum == 0) {
// Cost-aware selection
select node with minimum workload score
if (maxScore - minScore) / minScore <= WorkloadToleranceFactor {
fall back to round-robin
}
} else {
// Fast path: round-robin
select next available node
}
```
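In Go terms, the cost-aware branch looks roughly like the sketch below (helper names are illustrative and the `math`/`internalpb` imports are assumed; only the score formula and tolerance check mirror the behavior described above):
```go
// Workload score for one node, from its last reported CostAggregation.
func calculateScore(cost *internalpb.CostAggregation, executingNQ int64) float64 {
    executeSpeed := float64(cost.ResponseTime) - float64(cost.ServiceTime)
    load := math.Pow(float64(1+cost.TotalNQ+executingNQ), 3.0) * float64(cost.ServiceTime)
    return executeSpeed + load
}

// Pick the node with the minimum score; report false when the scores are so
// close together that the caller should just use round-robin instead.
func pickByScore(scores map[int64]float64, tolerance float64) (int64, bool) {
    minNode, minScore, maxScore := int64(-1), math.MaxFloat64, -math.MaxFloat64
    for node, s := range scores {
        if s < minScore {
            minNode, minScore = node, s
        }
        if s > maxScore {
            maxScore = s
        }
    }
    if minNode == -1 {
        return -1, false
    }
    if minScore > 0 && (maxScore-minScore)/minScore <= tolerance {
        return -1, false // workloads are balanced enough: fall back to round-robin
    }
    return minNode, true
}
```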
## Configuration
Key configuration parameters from `paramtable`:
| Parameter | Path | Description | Default |
|-----------|------|-------------|---------|
| QueryNodePoolingSize | `ProxyCfg.QueryNodePoolingSize` | Size of connection pool per QueryNode | 1 |
| RetryTimesOnReplica | `ProxyCfg.RetryTimesOnReplica` | Max retry times on replica failures | varies |
| ReplicaSelectionPolicy | `ProxyCfg.ReplicaSelectionPolicy` | Load balancing policy: `round_robin` or `look_aside` | `look_aside` |
| CostMetricsExpireTime | `ProxyCfg.CostMetricsExpireTime` | Expiration time for cost metrics cache | varies |
| CheckWorkloadRequestNum | `ProxyCfg.CheckWorkloadRequestNum` | Frequency of workload-aware selection | varies |
| WorkloadToleranceFactor | `ProxyCfg.WorkloadToleranceFactor` | Tolerance for workload differences | varies |
| CheckQueryNodeHealthInterval | `ProxyCfg.CheckQueryNodeHealthInterval` | Health check interval | varies |
| HealthCheckTimeout | `ProxyCfg.HealthCheckTimeout` | Health check RPC timeout | varies |
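These values are read through `paramtable` when the components are constructed; for example (getter names follow the patterns used elsewhere in this package, `time` import assumed):
```go
params := paramtable.Get()
poolSize := params.ProxyCfg.QueryNodePoolingSize.GetAsInt()
policy := params.ProxyCfg.ReplicaSelectionPolicy.GetValue() // "round_robin" or "look_aside"
metricsExpireMs := params.ProxyCfg.CostMetricsExpireTime.GetAsInt64()
healthInterval := params.ProxyCfg.CheckQueryNodeHealthInterval.GetAsDuration(time.Millisecond)
```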
## Usage Example
```go
import (
"context"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
)
// 1. Create ShardClientMgr with MixCoord client
mgr := shardclient.NewShardClientMgr(mixCoordClient)
mgr.Start() // Start background purge goroutine
defer mgr.Close()
// 2. Create LBPolicy
policy := shardclient.NewLBPolicyImpl(mgr)
policy.Start(ctx) // Start load balancer (health checks, etc.)
defer policy.Close()
// 3. Execute collection workload (e.g., search/query)
workload := shardclient.CollectionWorkLoad{
Db: "default",
CollectionName: "my_collection",
CollectionID: 12345,
Nq: 100, // Number of queries
Exec: func(ctx context.Context, nodeID int64, client types.QueryNodeClient, channel string) error {
// Perform actual work (search, query, etc.)
req := &querypb.SearchRequest{/* ... */}
_, err := client.Search(ctx, req)
return err
},
}
// Execute on all channels in parallel
err := policy.Execute(ctx, workload)
// Or execute on any single channel (for lightweight ops)
err = policy.ExecuteOneChannel(ctx, workload)
```
## Cache Management
### Shard Leader Cache
The shard leader cache stores the mapping of shards to their leader QueryNodes:
```
database → collectionName → shardLeaders {
collectionID: int64
shardLeaders: map[channel][]NodeInfo
}
```
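Expressed as Go types, the layout is roughly (illustrative names, not the actual structs):
```go
// channel -> replicas currently serving that channel (the shard leaders)
type shardLeaders struct {
    collectionID int64
    leaders      map[string][]NodeInfo
}

// database -> collection name -> cached shard leaders
type leaderCache map[string]map[string]*shardLeaders
```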
**Cache Operations**:
- **Hit**: When cached shard leaders are used (tracked via `ProxyCacheStatsCounter`)
- **Miss**: When cache lookup fails, triggers RPC to QueryCoord via `GetShardLeaders`
- **Invalidation**:
- `DeprecateShardCache(db, collection)`: Remove specific collection
- `InvalidateShardLeaderCache(collectionIDs)`: Remove collections by ID (called on shard leader changes)
- `RemoveDatabase(db)`: Remove entire database
### Client Purging
The `ShardClientMgr` periodically purges unused clients:
1. Every `purgeInterval` (default: 600s), iterate all cached clients
2. Check if client is still a shard leader (via `ListShardLocation()`)
3. If it is no longer a leader and has been idle longer than `expiredDuration` (measured from `lastActiveTs`), close and remove it
4. This prevents connection leaks when QueryNodes are removed or shards rebalance
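A sketch of that purge loop (assuming a `typeutil.ConcurrentMap` of clients and a `listLeaders` callback over the shard leader cache; not the actual implementation):
```go
// Purge clients that are no longer shard leaders and have been idle too long.
func purgeLoop(
    closeCh <-chan struct{},
    clients *typeutil.ConcurrentMap[int64, *shardClient],
    listLeaders func() map[int64]NodeInfo,
    purgeInterval, expiredDuration time.Duration,
) {
    ticker := time.NewTicker(purgeInterval)
    defer ticker.Stop()
    for {
        select {
        case <-closeCh:
            return
        case <-ticker.C:
            leaders := listLeaders() // nodeID -> NodeInfo currently serving shards
            clients.Range(func(nodeID int64, c *shardClient) bool {
                _, stillLeader := leaders[nodeID]
                idle := time.Now().UnixNano()-c.lastActiveTs.Load() > expiredDuration.Nanoseconds()
                if !stillLeader && idle { // only purge idle, non-leader clients
                    c.Close(true)
                    clients.Remove(nodeID)
                }
                return true
            })
        }
    }
}
```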
## Error Handling
### Common Errors
- **`errClosed`**: Client is closed (returned when accessing closed `shardClient`)
- **`merr.ErrChannelNotAvailable`**: No available shard leaders for channel
- **`merr.ErrNodeNotAvailable`**: Selected node is not available
- **`merr.ErrCollectionNotLoaded`**: Collection is not loaded in QueryNodes
- **`merr.ErrServiceUnavailable`**: All available nodes are unreachable
### Retry Logic
Retry is handled at multiple levels:
1. **LBPolicy level**:
- Retries on different replicas when request fails
- Refreshes shard leader cache on failure
- Respects context cancellation
2. **Balancer level**:
- Tracks failed nodes and excludes them from selection
- Health checks recover nodes when they come back online
3. **gRPC level**:
- Connection-level retries handled by gRPC layer
## Metrics
The package exports several metrics:
- `ProxyCacheStatsCounter`: Shard leader cache hit/miss statistics
- Labels: `nodeID`, `method` (GetShard/GetShardLeaderList), `status` (hit/miss)
- `ProxyUpdateCacheLatency`: Latency of updating shard leader cache
- Labels: `nodeID`, `method`
## Testing
The package includes extensive test coverage:
- `shard_client_test.go`: Tests for connection pool management
- `manager_test.go`: Tests for cache management and client lifecycle
- `lb_policy_test.go`: Tests for retry logic and workload execution
- `roundrobin_balancer_test.go`: Tests for round-robin selection
- `look_aside_balancer_test.go`: Tests for cost-aware selection and health checks
**Mock interfaces** (via mockery):
- `mock_shardclient_manager.go`: Mock `ShardClientMgr`
- `mock_lb_policy.go`: Mock `LBPolicy`
- `mock_lb_balancer.go`: Mock `LBBalancer`
## Thread Safety
All components are designed for concurrent access:
- `shardClientMgrImpl`: Uses `sync.RWMutex` for cache, `typeutil.ConcurrentMap` for clients
- `shardClient`: Uses `sync.RWMutex` and atomic operations
- `LookAsideBalancer`: Uses `typeutil.ConcurrentMap` for all mutable state
- `RoundRobinBalancer`: Uses `atomic.Int64` for index
## Related Components
- **Proxy** (`internal/proxy/`): Uses `shardclient` to route search/query requests to QueryNodes
- **QueryCoord** (`internal/querycoordv2/`): Provides shard leader information via `GetShardLeaders` RPC
- **QueryNode** (`internal/querynodev2/`): Receives and processes requests routed by `shardclient`
- **Registry** (`internal/registry/`): Provides client creation functions for gRPC connections
## Future Improvements
Potential areas for enhancement:
1. **Adaptive pooling**: Dynamically adjust connection pool size based on load
2. **Circuit breaker**: Add circuit breaker pattern for consistently failing nodes
3. **Advanced metrics**: Export more detailed metrics (per-node latency, error rates, etc.)
4. **Smart caching**: Use TTL-based cache expiration instead of invalidation-only
5. **Connection warming**: Pre-establish connections to known QueryNodes

View File

@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
package shardclient
import (
"context"
@ -23,7 +23,7 @@ import (
)
type LBBalancer interface {
RegisterNodeInfo(nodeInfos []nodeInfo)
RegisterNodeInfo(nodeInfos []NodeInfo)
SelectNode(ctx context.Context, availableNodes []int64, nq int64) (int64, error)
CancelWorkload(node int64, nq int64)
UpdateCostMetrics(node int64, cost *internalpb.CostAggregation)

View File

@ -13,7 +13,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
package shardclient
import (
"context"
@ -30,27 +31,28 @@ import (
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/retry"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
type executeFunc func(context.Context, UniqueID, types.QueryNodeClient, string) error
type ExecuteFunc func(context.Context, UniqueID, types.QueryNodeClient, string) error
type ChannelWorkload struct {
db string
collectionName string
collectionID int64
channel string
nq int64
exec executeFunc
Db string
CollectionName string
CollectionID int64
Channel string
Nq int64
Exec ExecuteFunc
}
type CollectionWorkLoad struct {
db string
collectionName string
collectionID int64
nq int64
exec executeFunc
Db string
CollectionName string
CollectionID int64
Nq int64
Exec ExecuteFunc
}
type LBPolicy interface {
@ -69,12 +71,12 @@ const (
type LBPolicyImpl struct {
getBalancer func() LBBalancer
clientMgr shardClientMgr
clientMgr ShardClientMgr
balancerMap map[string]LBBalancer
retryOnReplica int
}
func NewLBPolicyImpl(clientMgr shardClientMgr) *LBPolicyImpl {
func NewLBPolicyImpl(clientMgr ShardClientMgr) *LBPolicyImpl {
balancerMap := make(map[string]LBBalancer)
balancerMap[LookAside] = NewLookAsideBalancer(clientMgr)
balancerMap[RoundRobin] = NewRoundRobinBalancer()
@ -87,7 +89,7 @@ func NewLBPolicyImpl(clientMgr shardClientMgr) *LBPolicyImpl {
return balancerMap[balancePolicy]
}
retryOnReplica := Params.ProxyCfg.RetryTimesOnReplica.GetAsInt()
retryOnReplica := paramtable.Get().ProxyCfg.RetryTimesOnReplica.GetAsInt()
return &LBPolicyImpl{
getBalancer: getBalancer,
@ -105,11 +107,11 @@ func (lb *LBPolicyImpl) Start(ctx context.Context) {
// GetShard will retry until ctx done, except the collection is not loaded.
// return all replicas of shard from cache if withCache is true, otherwise return shard leaders from coord.
func (lb *LBPolicyImpl) GetShard(ctx context.Context, dbName string, collName string, collectionID int64, channel string, withCache bool) ([]nodeInfo, error) {
var shardLeaders []nodeInfo
func (lb *LBPolicyImpl) GetShard(ctx context.Context, dbName string, collName string, collectionID int64, channel string, withCache bool) ([]NodeInfo, error) {
var shardLeaders []NodeInfo
err := retry.Handle(ctx, func() (bool, error) {
var err error
shardLeaders, err = globalMetaCache.GetShard(ctx, withCache, dbName, collName, collectionID, channel)
shardLeaders, err = lb.clientMgr.GetShard(ctx, withCache, dbName, collName, collectionID, channel)
return !errors.Is(err, merr.ErrCollectionNotLoaded), err
})
return shardLeaders, err
@ -121,25 +123,25 @@ func (lb *LBPolicyImpl) GetShardLeaderList(ctx context.Context, dbName string, c
var ret []string
err := retry.Handle(ctx, func() (bool, error) {
var err error
ret, err = globalMetaCache.GetShardLeaderList(ctx, dbName, collName, collectionID, withCache)
ret, err = lb.clientMgr.GetShardLeaderList(ctx, dbName, collName, collectionID, withCache)
return !errors.Is(err, merr.ErrCollectionNotLoaded), err
})
return ret, err
}
// try to select the best node from the available nodes
func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes *typeutil.UniqueSet) (nodeInfo, error) {
func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, workload ChannelWorkload, excludeNodes *typeutil.UniqueSet) (NodeInfo, error) {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64("collectionID", workload.CollectionID),
zap.String("channelName", workload.Channel),
)
// Select node using specified nodes
trySelectNode := func(withCache bool) (nodeInfo, error) {
shardLeaders, err := lb.GetShard(ctx, workload.db, workload.collectionName, workload.collectionID, workload.channel, withCache)
trySelectNode := func(withCache bool) (NodeInfo, error) {
shardLeaders, err := lb.GetShard(ctx, workload.Db, workload.CollectionName, workload.CollectionID, workload.Channel, withCache)
if err != nil {
log.Warn("failed to get shard delegator",
zap.Error(err))
return nodeInfo{}, err
return NodeInfo{}, err
}
// if all available delegator has been excluded even after refresh shard leader cache
@ -147,7 +149,7 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
if !withCache && len(shardLeaders) > 0 && len(shardLeaders) <= excludeNodes.Len() {
allReplicaExcluded := true
for _, node := range shardLeaders {
if !excludeNodes.Contain(node.nodeID) {
if !excludeNodes.Contain(node.NodeID) {
allReplicaExcluded = false
break
}
@ -158,14 +160,14 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
}
}
candidateNodes := make(map[int64]nodeInfo)
serviceableNodes := make(map[int64]nodeInfo)
candidateNodes := make(map[int64]NodeInfo)
serviceableNodes := make(map[int64]NodeInfo)
defer func() {
if err != nil {
candidatesInStr := lo.Map(shardLeaders, func(node nodeInfo, _ int) string {
candidatesInStr := lo.Map(shardLeaders, func(node NodeInfo, _ int) string {
return node.String()
})
serviceableNodesInStr := lo.Map(lo.Values(serviceableNodes), func(node nodeInfo, _ int) string {
serviceableNodesInStr := lo.Map(lo.Values(serviceableNodes), func(node NodeInfo, _ int) string {
return node.String()
})
log.Warn("failed to select shard",
@ -178,33 +180,33 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
// Filter nodes based on excludeNodes
for _, node := range shardLeaders {
if !excludeNodes.Contain(node.nodeID) {
if node.serviceable {
serviceableNodes[node.nodeID] = node
if !excludeNodes.Contain(node.NodeID) {
if node.Serviceable {
serviceableNodes[node.NodeID] = node
}
candidateNodes[node.nodeID] = node
candidateNodes[node.NodeID] = node
}
}
if len(candidateNodes) == 0 {
err = merr.WrapErrChannelNotAvailable(workload.channel, "no available shard leaders")
return nodeInfo{}, err
err = merr.WrapErrChannelNotAvailable(workload.Channel, "no available shard leaders")
return NodeInfo{}, err
}
balancer.RegisterNodeInfo(lo.Values(candidateNodes))
// prefer serviceable nodes
var targetNodeID int64
if len(serviceableNodes) > 0 {
targetNodeID, err = balancer.SelectNode(ctx, lo.Keys(serviceableNodes), workload.nq)
targetNodeID, err = balancer.SelectNode(ctx, lo.Keys(serviceableNodes), workload.Nq)
} else {
targetNodeID, err = balancer.SelectNode(ctx, lo.Keys(candidateNodes), workload.nq)
targetNodeID, err = balancer.SelectNode(ctx, lo.Keys(candidateNodes), workload.Nq)
}
if err != nil {
return nodeInfo{}, err
return NodeInfo{}, err
}
if _, ok := candidateNodes[targetNodeID]; !ok {
err = merr.WrapErrNodeNotAvailable(targetNodeID)
return nodeInfo{}, err
return NodeInfo{}, err
}
return candidateNodes[targetNodeID], nil
@ -218,7 +220,7 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
withShardLeaderCache = false
targetNode, err = trySelectNode(withShardLeaderCache)
if err != nil {
return nodeInfo{}, err
return NodeInfo{}, err
}
}
@ -228,8 +230,8 @@ func (lb *LBPolicyImpl) selectNode(ctx context.Context, balancer LBBalancer, wor
// ExecuteWithRetry will choose a qn to execute the workload, and retry if failed, until reach the max retryTimes.
func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWorkload) error {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", workload.collectionID),
zap.String("channelName", workload.channel),
zap.Int64("collectionID", workload.CollectionID),
zap.String("channelName", workload.Channel),
)
var lastErr error
excludeNodes := typeutil.NewUniqueSet()
@ -238,7 +240,7 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
targetNode, err := lb.selectNode(ctx, balancer, workload, &excludeNodes)
if err != nil {
log.Warn("failed to select node for shard",
zap.Int64("nodeID", targetNode.nodeID),
zap.Int64("nodeID", targetNode.NodeID),
zap.Int64s("excluded", excludeNodes.Collect()),
zap.Error(err),
)
@ -248,33 +250,33 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
return true, err
}
// cancel work load which assign to the target node
defer balancer.CancelWorkload(targetNode.nodeID, workload.nq)
defer balancer.CancelWorkload(targetNode.NodeID, workload.Nq)
client, err := lb.clientMgr.GetClient(ctx, targetNode)
if err != nil {
log.Warn("search/query channel failed, node not available",
zap.Int64("nodeID", targetNode.nodeID),
zap.Int64("nodeID", targetNode.NodeID),
zap.Error(err))
excludeNodes.Insert(targetNode.nodeID)
excludeNodes.Insert(targetNode.NodeID)
lastErr = errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode.nodeID, workload.channel)
lastErr = errors.Wrapf(err, "failed to get delegator %d for channel %s", targetNode.NodeID, workload.Channel)
return true, lastErr
}
err = workload.exec(ctx, targetNode.nodeID, client, workload.channel)
err = workload.Exec(ctx, targetNode.NodeID, client, workload.Channel)
if err != nil {
log.Warn("search/query channel failed",
zap.Int64("nodeID", targetNode.nodeID),
zap.Int64("nodeID", targetNode.NodeID),
zap.Error(err))
excludeNodes.Insert(targetNode.nodeID)
lastErr = errors.Wrapf(err, "failed to search/query delegator %d for channel %s", targetNode.nodeID, workload.channel)
excludeNodes.Insert(targetNode.NodeID)
lastErr = errors.Wrapf(err, "failed to search/query delegator %d for channel %s", targetNode.NodeID, workload.Channel)
return true, lastErr
}
return true, nil
}
shardLeaders, err := lb.GetShard(ctx, workload.db, workload.collectionName, workload.collectionID, workload.channel, true)
shardLeaders, err := lb.GetShard(ctx, workload.Db, workload.CollectionName, workload.CollectionID, workload.Channel, true)
if err != nil {
log.Warn("failed to get shard leaders", zap.Error(err))
return err
@ -283,7 +285,7 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
err = retry.Handle(ctx, tryExecute, retry.Attempts(uint(retryTimes)))
if err != nil {
log.Warn("failed to execute",
zap.String("channel", workload.channel),
zap.String("channel", workload.Channel),
zap.Error(err))
}
@ -293,17 +295,17 @@ func (lb *LBPolicyImpl) ExecuteWithRetry(ctx context.Context, workload ChannelWo
// Execute will execute collection workload in parallel
func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad) error {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", workload.collectionID),
zap.Int64("collectionID", workload.CollectionID),
)
channelList, err := lb.GetShardLeaderList(ctx, workload.db, workload.collectionName, workload.collectionID, true)
channelList, err := lb.GetShardLeaderList(ctx, workload.Db, workload.CollectionName, workload.CollectionID, true)
if err != nil {
log.Warn("failed to get shards", zap.Error(err))
return err
}
if len(channelList) == 0 {
log.Info("no shard leaders found", zap.Int64("collectionID", workload.collectionID))
return merr.WrapErrCollectionNotLoaded(workload.collectionID)
log.Info("no shard leaders found", zap.Int64("collectionID", workload.CollectionID))
return merr.WrapErrCollectionNotLoaded(workload.CollectionID)
}
wg, _ := errgroup.WithContext(ctx)
@ -311,12 +313,12 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
for _, channel := range channelList {
wg.Go(func() error {
return lb.ExecuteWithRetry(ctx, ChannelWorkload{
db: workload.db,
collectionName: workload.collectionName,
collectionID: workload.collectionID,
channel: channel,
nq: workload.nq,
exec: workload.exec,
Db: workload.Db,
CollectionName: workload.CollectionName,
CollectionID: workload.CollectionID,
Channel: channel,
Nq: workload.Nq,
Exec: workload.Exec,
})
})
}
@ -325,7 +327,7 @@ func (lb *LBPolicyImpl) Execute(ctx context.Context, workload CollectionWorkLoad
// Execute will execute any one channel in collection workload
func (lb *LBPolicyImpl) ExecuteOneChannel(ctx context.Context, workload CollectionWorkLoad) error {
channelList, err := lb.GetShardLeaderList(ctx, workload.db, workload.collectionName, workload.collectionID, true)
channelList, err := lb.GetShardLeaderList(ctx, workload.Db, workload.CollectionName, workload.CollectionID, true)
if err != nil {
log.Ctx(ctx).Warn("failed to get shards", zap.Error(err))
return err
@ -334,15 +336,15 @@ func (lb *LBPolicyImpl) ExecuteOneChannel(ctx context.Context, workload Collecti
// let every request could retry at least twice, which could retry after update shard leader cache
for _, channel := range channelList {
return lb.ExecuteWithRetry(ctx, ChannelWorkload{
db: workload.db,
collectionName: workload.collectionName,
collectionID: workload.collectionID,
channel: channel,
nq: workload.nq,
exec: workload.exec,
Db: workload.Db,
CollectionName: workload.CollectionName,
CollectionID: workload.CollectionID,
Channel: channel,
Nq: workload.Nq,
Exec: workload.Exec,
})
}
return fmt.Errorf("no acitvate sheard leader exist for collection: %s", workload.collectionName)
return fmt.Errorf("no acitvate sheard leader exist for collection: %s", workload.CollectionName)
}
func (lb *LBPolicyImpl) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {

View File

@ -0,0 +1,664 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shardclient
import (
"context"
"reflect"
"testing"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"go.uber.org/atomic"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
type LBPolicySuite struct {
suite.Suite
qn *mocks.MockQueryNodeClient
mgr *MockShardClientManager
lbBalancer *MockLBBalancer
lbPolicy *LBPolicyImpl
nodeIDs []int64
nodes []NodeInfo
channels []string
dbName string
collectionName string
collectionID int64
}
func (s *LBPolicySuite) SetupSuite() {
paramtable.Init()
}
func (s *LBPolicySuite) SetupTest() {
s.nodeIDs = make([]int64, 0)
s.nodes = make([]NodeInfo, 0)
for i := 1; i <= 5; i++ {
s.nodeIDs = append(s.nodeIDs, int64(i))
s.nodes = append(s.nodes, NodeInfo{
NodeID: int64(i),
Address: "localhost",
Serviceable: true,
})
}
s.channels = []string{"channel1", "channel2"}
s.qn = mocks.NewMockQueryNodeClient(s.T())
s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
s.mgr = NewMockShardClientManager(s.T())
s.lbBalancer = NewMockLBBalancer(s.T())
s.lbBalancer.EXPECT().Start(mock.Anything).Maybe()
s.lbBalancer.EXPECT().Close().Maybe()
s.lbPolicy = NewLBPolicyImpl(s.mgr)
s.lbPolicy.Start(context.Background())
s.lbPolicy.getBalancer = func() LBBalancer {
return s.lbBalancer
}
s.dbName = "test_lb_policy"
s.collectionName = "test_lb_policy"
s.collectionID = 100
}
func (s *LBPolicySuite) TearDownTest() {
s.lbPolicy.Close()
}
func (s *LBPolicySuite) TestSelectNode() {
ctx := context.Background()
// test select node success
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil)
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(5), nil)
excludeNodes := typeutil.NewUniqueSet()
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Channel: s.channels[0],
Nq: 1,
}, &excludeNodes)
s.NoError(err)
s.Equal(int64(5), targetNode.NodeID)
// test select node failed, then update shard leader cache and retry, expect success
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
// First call with cache fails
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Once()
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(-1), errors.New("fake err")).Once()
// Second call without cache succeeds
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Once()
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(3), nil).Once()
excludeNodes = typeutil.NewUniqueSet()
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Channel: s.channels[0],
Nq: 1,
}, &excludeNodes)
s.NoError(err)
s.Equal(int64(3), targetNode.NodeID)
// test select node always fails, expected failure
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Once()
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Once()
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(-1), merr.ErrNodeNotAvailable)
excludeNodes = typeutil.NewUniqueSet()
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Channel: s.channels[0],
Nq: 1,
}, &excludeNodes)
s.ErrorIs(err, merr.ErrNodeNotAvailable)
// test all nodes has been excluded, expected clear excludeNodes and try to select node again
excludeNodes = typeutil.NewUniqueSet(s.nodeIDs...)
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Once()
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Once()
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(-1), merr.ErrNodeNotAvailable)
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Channel: s.channels[0],
Nq: 1,
}, &excludeNodes)
s.ErrorIs(err, merr.ErrNodeNotAvailable)
// test get shard leaders failed, retry to select node failed
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(nil, merr.ErrCollectionNotLoaded).Once()
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(nil, merr.ErrCollectionNotLoaded).Once()
excludeNodes = typeutil.NewUniqueSet()
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Channel: s.channels[0],
Nq: 1,
}, &excludeNodes)
s.ErrorIs(err, merr.ErrCollectionNotLoaded)
}
func (s *LBPolicySuite) TestExecuteWithRetry() {
ctx := context.Background()
// test execute success
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil)
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err := s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Channel: s.channels[0],
Nq: 1,
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
})
s.NoError(err)
// test select node failed, expected error
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil)
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Maybe()
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(-1), merr.ErrNodeNotAvailable)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Channel: s.channels[0],
Nq: 1,
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
})
s.ErrorIs(err, merr.ErrNodeNotAvailable)
// test get client failed, and retry failed, expected error
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil)
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Maybe()
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error"))
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
return availableNodes[0], nil
})
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Channel: s.channels[0],
Nq: 1,
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
})
s.Error(err)
// test get client failed once, then retry success
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil)
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Maybe()
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(nil, errors.New("fake error")).Once()
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
return availableNodes[0], nil
})
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Channel: s.channels[0],
Nq: 1,
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
})
s.NoError(err)
// test exec failed, then retry success
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil)
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Maybe()
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
return availableNodes[0], nil
})
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
counter := 0
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Channel: s.channels[0],
Nq: 1,
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
counter++
if counter == 1 {
return errors.New("fake error")
}
return nil
},
})
s.NoError(err)
// test exec timeout
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil)
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil).Maybe()
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
return availableNodes[0], nil
})
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
s.qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, context.Canceled).Once()
s.qn.EXPECT().Search(mock.Anything, mock.Anything).Return(nil, context.DeadlineExceeded)
err = s.lbPolicy.ExecuteWithRetry(ctx, ChannelWorkload{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Channel: s.channels[0],
Nq: 1,
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
_, err := qn.Search(ctx, nil)
return err
},
})
s.True(merr.IsCanceledOrTimeout(err))
}
func (s *LBPolicySuite) TestExecuteOneChannel() {
ctx := context.Background()
mockErr := errors.New("mock error")
// test all channel success
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().GetShardLeaderList(mock.Anything, s.dbName, s.collectionName, s.collectionID, true).Return(s.channels, nil)
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil)
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil)
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err := s.lbPolicy.ExecuteOneChannel(ctx, CollectionWorkLoad{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Nq: 1,
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
})
s.NoError(err)
// test get shard leader failed
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().GetShardLeaderList(mock.Anything, s.dbName, s.collectionName, s.collectionID, true).Return(nil, mockErr)
err = s.lbPolicy.ExecuteOneChannel(ctx, CollectionWorkLoad{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Nq: 1,
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
})
s.ErrorIs(err, mockErr)
}
func (s *LBPolicySuite) TestExecute() {
ctx := context.Background()
mockErr := errors.New("mock error")
// test all channel success
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().GetShardLeaderList(mock.Anything, s.dbName, s.collectionName, s.collectionID, true).Return(s.channels, nil)
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, mock.Anything).Return(s.nodes, nil)
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, mock.Anything).Return(s.nodes, nil).Maybe()
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
return availableNodes[0], nil
})
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
err := s.lbPolicy.Execute(ctx, CollectionWorkLoad{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Nq: 1,
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
})
s.NoError(err)
// test some channel failed
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().GetShardLeaderList(mock.Anything, s.dbName, s.collectionName, s.collectionID, true).Return(s.channels, nil)
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, mock.Anything).Return(s.nodes, nil)
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, mock.Anything).Return(s.nodes, nil).Maybe()
s.mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil)
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, availableNodes []int64, nq int64) (int64, error) {
return availableNodes[0], nil
})
s.lbBalancer.EXPECT().CancelWorkload(mock.Anything, mock.Anything)
counter := atomic.NewInt64(0)
err = s.lbPolicy.Execute(ctx, CollectionWorkLoad{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Nq: 1,
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
// succeed in first execute
if counter.Add(1) == 1 {
return nil
}
return mockErr
},
})
s.Error(err)
s.Equal(int64(6), counter.Load())
// test get shard leader failed
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().GetShardLeaderList(mock.Anything, s.dbName, s.collectionName, s.collectionID, true).Return(nil, mockErr)
err = s.lbPolicy.Execute(ctx, CollectionWorkLoad{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Nq: 1,
Exec: func(ctx context.Context, ui UniqueID, qn types.QueryNodeClient, channel string) error {
return nil
},
})
s.ErrorIs(err, mockErr)
}
func (s *LBPolicySuite) TestUpdateCostMetrics() {
s.lbBalancer.EXPECT().UpdateCostMetrics(mock.Anything, mock.Anything)
s.lbPolicy.UpdateCostMetrics(1, &internalpb.CostAggregation{})
}
func (s *LBPolicySuite) TestNewLBPolicy() {
mgr := NewMockShardClientManager(s.T())
policy := NewLBPolicyImpl(mgr)
s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*shardclient.LookAsideBalancer")
policy.Close()
params := paramtable.Get()
params.Save(params.ProxyCfg.ReplicaSelectionPolicy.Key, "round_robin")
policy = NewLBPolicyImpl(mgr)
s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*shardclient.RoundRobinBalancer")
policy.Close()
params.Save(params.ProxyCfg.ReplicaSelectionPolicy.Key, "look_aside")
policy = NewLBPolicyImpl(mgr)
s.Equal(reflect.TypeOf(policy.getBalancer()).String(), "*shardclient.LookAsideBalancer")
policy.Close()
}
func (s *LBPolicySuite) TestGetShard() {
ctx := context.Background()
// ErrCollectionNotLoaded is not retriable, expected to fail fast
counter := atomic.NewInt64(0)
s.mgr.ExpectedCalls = nil
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).RunAndReturn(
func(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]NodeInfo, error) {
counter.Inc()
return nil, merr.ErrCollectionNotLoaded
})
_, err := s.lbPolicy.GetShard(ctx, s.dbName, s.collectionName, s.collectionID, s.channels[0], true)
s.Error(err)
s.Equal(int64(1), counter.Load())
// Normal case - success
s.mgr.ExpectedCalls = nil
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(s.nodes, nil)
shardLeaders, err := s.lbPolicy.GetShard(ctx, s.dbName, s.collectionName, s.collectionID, s.channels[0], true)
s.NoError(err)
s.Equal(len(s.nodes), len(shardLeaders))
}
func (s *LBPolicySuite) TestSelectNodeEdgeCases() {
ctx := context.Background()
// Test case 1: Empty shard leaders after refresh, should fail gracefully
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return([]NodeInfo{}, nil).Once()
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return([]NodeInfo{}, nil).Once()
excludeNodes := typeutil.NewUniqueSet(s.nodeIDs...)
_, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Channel: s.channels[0],
Nq: 1,
}, &excludeNodes)
s.Error(err)
// Test case 2: Single replica scenario - exclude it, refresh shows same single replica, should clear and succeed
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
singleNode := []NodeInfo{{NodeID: 1, Address: "localhost:9000", Serviceable: true}}
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(singleNode, nil).Once()
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(singleNode, nil).Once()
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil).Once()
excludeNodes = typeutil.NewUniqueSet(int64(1)) // Exclude the single node
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Channel: s.channels[0],
Nq: 1,
}, &excludeNodes)
s.NoError(err)
s.Equal(int64(1), targetNode.NodeID)
s.Equal(0, excludeNodes.Len()) // Should be cleared
// Test case 3: Mixed serviceable nodes - prefer serviceable over non-serviceable
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
mixedNodes := []NodeInfo{
{NodeID: 1, Address: "localhost:9000", Serviceable: true},
{NodeID: 2, Address: "localhost:9001", Serviceable: false},
{NodeID: 3, Address: "localhost:9002", Serviceable: true},
}
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(mixedNodes, nil).Once()
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
// Should select from serviceable nodes only (node 3, since node 1 is excluded)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.MatchedBy(func(nodes []int64) bool {
return len(nodes) == 1 && nodes[0] == 3 // Only node 3 is serviceable and not excluded
}), mock.Anything).Return(int64(3), nil).Once()
excludeNodes = typeutil.NewUniqueSet(int64(1)) // Exclude node 1, node 3 should be available
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Channel: s.channels[0],
Nq: 1,
}, &excludeNodes)
s.NoError(err)
s.Equal(int64(3), targetNode.NodeID)
s.Equal(1, excludeNodes.Len()) // Should NOT be cleared as not all replicas were excluded
}
func (s *LBPolicySuite) TestGetShardLeaderList() {
ctx := context.Background()
// Test normal scenario with cache
s.mgr.ExpectedCalls = nil
s.mgr.EXPECT().GetShardLeaderList(mock.Anything, s.dbName, s.collectionName, s.collectionID, true).Return(s.channels, nil)
channelList, err := s.lbPolicy.GetShardLeaderList(ctx, s.dbName, s.collectionName, s.collectionID, true)
s.NoError(err)
s.Equal(len(s.channels), len(channelList))
s.Contains(channelList, s.channels[0])
s.Contains(channelList, s.channels[1])
// Test without cache - should refresh from coordinator
s.mgr.ExpectedCalls = nil
s.mgr.EXPECT().GetShardLeaderList(mock.Anything, s.dbName, s.collectionName, s.collectionID, false).Return(s.channels, nil)
channelList, err = s.lbPolicy.GetShardLeaderList(ctx, s.dbName, s.collectionName, s.collectionID, false)
s.NoError(err)
s.Equal(len(s.channels), len(channelList))
// Test error case - collection not loaded
counter := atomic.NewInt64(0)
s.mgr.ExpectedCalls = nil
s.mgr.EXPECT().GetShardLeaderList(mock.Anything, s.dbName, s.collectionName, s.collectionID, true).RunAndReturn(
func(ctx context.Context, database, collectionName string, collectionID int64, withCache bool) ([]string, error) {
counter.Inc()
return nil, merr.ErrCollectionNotLoaded
})
_, err = s.lbPolicy.GetShardLeaderList(ctx, s.dbName, s.collectionName, s.collectionID, true)
s.Error(err)
s.ErrorIs(err, merr.ErrCollectionNotLoaded)
s.Equal(int64(1), counter.Load())
}
func (s *LBPolicySuite) TestSelectNodeWithExcludeClearing() {
ctx := context.Background()
// Test exclude nodes clearing when all replicas are excluded after cache refresh
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
twoNodes := []NodeInfo{
{NodeID: 1, Address: "localhost:9000", Serviceable: true},
{NodeID: 2, Address: "localhost:9001", Serviceable: true},
}
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(twoNodes, nil).Once()
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(twoNodes, nil).Once()
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil).Once()
excludeNodes := typeutil.NewUniqueSet(int64(1), int64(2)) // Exclude all available nodes
targetNode, err := s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Channel: s.channels[0],
Nq: 1,
}, &excludeNodes)
s.NoError(err)
s.Equal(int64(1), targetNode.NodeID)
s.Equal(0, excludeNodes.Len()) // Should be cleared when all replicas were excluded
// Test exclude nodes NOT cleared when only partial replicas are excluded
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
threeNodes := []NodeInfo{
{NodeID: 1, Address: "localhost:9000", Serviceable: true},
{NodeID: 2, Address: "localhost:9001", Serviceable: true},
{NodeID: 3, Address: "localhost:9002", Serviceable: true},
}
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return(threeNodes, nil).Once()
s.lbBalancer.EXPECT().RegisterNodeInfo(mock.Anything)
s.lbBalancer.EXPECT().SelectNode(mock.Anything, mock.Anything, mock.Anything).Return(int64(2), nil).Once()
excludeNodes = typeutil.NewUniqueSet(int64(1)) // Only exclude node 1
targetNode, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Channel: s.channels[0],
Nq: 1,
}, &excludeNodes)
s.NoError(err)
s.Equal(int64(2), targetNode.NodeID)
s.Equal(1, excludeNodes.Len()) // Should NOT be cleared as not all replicas were excluded
// Test empty shard leaders scenario
s.mgr.ExpectedCalls = nil
s.lbBalancer.ExpectedCalls = nil
s.mgr.EXPECT().GetShard(mock.Anything, true, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return([]NodeInfo{}, nil).Once()
s.mgr.EXPECT().GetShard(mock.Anything, false, s.dbName, s.collectionName, s.collectionID, s.channels[0]).Return([]NodeInfo{}, nil).Once()
excludeNodes = typeutil.NewUniqueSet(int64(1))
_, err = s.lbPolicy.selectNode(ctx, s.lbBalancer, ChannelWorkload{
Db: s.dbName,
CollectionName: s.collectionName,
CollectionID: s.collectionID,
Channel: s.channels[0],
Nq: 1,
}, &excludeNodes)
s.Error(err)
s.Equal(1, excludeNodes.Len()) // Should NOT be cleared for empty shard leaders
}
func TestLBPolicySuite(t *testing.T) {
suite.Run(t, new(LBPolicySuite))
}

View File

@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
package shardclient
import (
"context"
@ -31,6 +31,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/util/conc"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
@ -42,9 +43,9 @@ type CostMetrics struct {
}
type LookAsideBalancer struct {
clientMgr shardClientMgr
clientMgr ShardClientMgr
knownNodeInfos *typeutil.ConcurrentMap[int64, nodeInfo]
knownNodeInfos *typeutil.ConcurrentMap[int64, NodeInfo]
metricsMap *typeutil.ConcurrentMap[int64, *CostMetrics]
// query node id -> number of consecutive heartbeat failures
failedHeartBeatCounter *typeutil.ConcurrentMap[int64, *atomic.Int64]
@ -62,18 +63,18 @@ type LookAsideBalancer struct {
workloadToleranceFactor float64
}
func NewLookAsideBalancer(clientMgr shardClientMgr) *LookAsideBalancer {
func NewLookAsideBalancer(clientMgr ShardClientMgr) *LookAsideBalancer {
balancer := &LookAsideBalancer{
clientMgr: clientMgr,
knownNodeInfos: typeutil.NewConcurrentMap[int64, nodeInfo](),
knownNodeInfos: typeutil.NewConcurrentMap[int64, NodeInfo](),
metricsMap: typeutil.NewConcurrentMap[int64, *CostMetrics](),
failedHeartBeatCounter: typeutil.NewConcurrentMap[int64, *atomic.Int64](),
closeCh: make(chan struct{}),
}
balancer.metricExpireInterval = Params.ProxyCfg.CostMetricsExpireTime.GetAsInt64()
balancer.checkWorkloadRequestNum = Params.ProxyCfg.CheckWorkloadRequestNum.GetAsInt64()
balancer.workloadToleranceFactor = Params.ProxyCfg.WorkloadToleranceFactor.GetAsFloat()
balancer.metricExpireInterval = paramtable.Get().ProxyCfg.CostMetricsExpireTime.GetAsInt64()
balancer.checkWorkloadRequestNum = paramtable.Get().ProxyCfg.CheckWorkloadRequestNum.GetAsInt64()
balancer.workloadToleranceFactor = paramtable.Get().ProxyCfg.WorkloadToleranceFactor.GetAsFloat()
return balancer
}
@ -90,9 +91,9 @@ func (b *LookAsideBalancer) Close() {
})
}
func (b *LookAsideBalancer) RegisterNodeInfo(nodeInfos []nodeInfo) {
func (b *LookAsideBalancer) RegisterNodeInfo(nodeInfos []NodeInfo) {
for _, node := range nodeInfos {
b.knownNodeInfos.Insert(node.nodeID, node)
b.knownNodeInfos.Insert(node.NodeID, node)
}
}
@ -227,7 +228,7 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) {
log := log.Ctx(ctx).WithRateGroup("proxy.LookAsideBalancer", 1, 60)
defer b.wg.Done()
checkHealthInterval := Params.ProxyCfg.CheckQueryNodeHealthInterval.GetAsDuration(time.Millisecond)
checkHealthInterval := paramtable.Get().ProxyCfg.CheckQueryNodeHealthInterval.GetAsDuration(time.Millisecond)
ticker := time.NewTicker(checkHealthInterval)
defer ticker.Stop()
log.Info("Start check query node health loop")
@ -241,11 +242,11 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) {
case <-ticker.C:
var futures []*conc.Future[any]
now := time.Now()
b.knownNodeInfos.Range(func(node int64, info nodeInfo) bool {
b.knownNodeInfos.Range(func(node int64, info NodeInfo) bool {
futures = append(futures, pool.Submit(func() (any, error) {
metrics, ok := b.metricsMap.Get(node)
if !ok || now.UnixMilli()-metrics.ts.Load() > checkHealthInterval.Milliseconds() {
checkTimeout := Params.ProxyCfg.HealthCheckTimeout.GetAsDuration(time.Millisecond)
checkTimeout := paramtable.Get().ProxyCfg.HealthCheckTimeout.GetAsDuration(time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), checkTimeout)
defer cancel()
@ -301,14 +302,14 @@ func (b *LookAsideBalancer) trySetQueryNodeUnReachable(node int64, err error) {
zap.Int64("times", failures.Load()),
zap.Error(err))
if failures.Load() < Params.ProxyCfg.RetryTimesOnHealthCheck.GetAsInt64() {
if failures.Load() < paramtable.Get().ProxyCfg.RetryTimesOnHealthCheck.GetAsInt64() {
return
}
// if the total time of consecutive heartbeat failures reaches the session.ttl, remove the offline query node
limit := Params.CommonCfg.SessionTTL.GetAsDuration(time.Second).Seconds() /
Params.ProxyCfg.HealthCheckTimeout.GetAsDuration(time.Millisecond).Seconds()
if failures.Load() > Params.ProxyCfg.RetryTimesOnHealthCheck.GetAsInt64() && float64(failures.Load()) >= limit {
limit := paramtable.Get().CommonCfg.SessionTTL.GetAsDuration(time.Second).Seconds() /
paramtable.Get().ProxyCfg.HealthCheckTimeout.GetAsDuration(time.Millisecond).Seconds()
if failures.Load() > paramtable.Get().ProxyCfg.RetryTimesOnHealthCheck.GetAsInt64() && float64(failures.Load()) >= limit {
log.Info("the heartbeat failures has reach it's upper limit, remove the query node",
zap.Int64("nodeID", node))
// stop the heartbeat
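
A minimal usage sketch of how the relocated balancer might be wired up, assuming paramtable has been initialized and that LookAsideBalancer exposes the same SelectNode signature as the RoundRobinBalancer shown later in this diff; the mixCoord client, node IDs, and addresses are placeholders, not values from this change.

// Sketch only; not part of this diff.
package example

import (
    "context"

    "github.com/milvus-io/milvus/internal/proxy/shardclient"
    "github.com/milvus-io/milvus/internal/types"
)

func pickLeastLoadedNode(ctx context.Context, mixCoord types.MixCoordClient) (int64, error) {
    // The manager owns shard-leader caching and query node client pooling.
    mgr := shardclient.NewShardClientMgr(mixCoord)
    mgr.Start()
    defer mgr.Close()

    balancer := shardclient.NewLookAsideBalancer(mgr)
    defer balancer.Close()

    // Shard leaders are registered through the exported NodeInfo type.
    balancer.RegisterNodeInfo([]shardclient.NodeInfo{
        {NodeID: 1, Address: "querynode-1:21123", Serviceable: true},
        {NodeID: 2, Address: "querynode-2:21123", Serviceable: true},
    })

    // Pick the least loaded candidate for a workload cost of 10.
    return balancer.SelectNode(ctx, []int64{1, 2}, 10)
}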

View File

@ -14,7 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
package shardclient
import (
"context"
@ -32,6 +32,7 @@ import (
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
type LookAsideBalancerSuite struct {
@ -308,12 +309,12 @@ func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() {
},
}, nil).Maybe()
suite.clientMgr.ExpectedCalls = nil
suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ni nodeInfo) (types.QueryNodeClient, error) {
if ni.nodeID == 1 {
suite.clientMgr.EXPECT().GetClient(mock.Anything, mock.Anything).RunAndReturn(func(ctx context.Context, ni NodeInfo) (types.QueryNodeClient, error) {
if ni.NodeID == 1 {
return qn, nil
}
if ni.nodeID == 2 {
if ni.NodeID == 2 {
return qn2, nil
}
return nil, errors.New("unexpected node")
@ -323,19 +324,19 @@ func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() {
metrics1.ts.Store(time.Now().UnixMilli())
metrics1.unavailable.Store(true)
suite.balancer.metricsMap.Insert(1, metrics1)
suite.balancer.RegisterNodeInfo([]nodeInfo{
suite.balancer.RegisterNodeInfo([]NodeInfo{
{
nodeID: 1,
NodeID: 1,
},
})
metrics2 := &CostMetrics{}
metrics2.ts.Store(time.Now().UnixMilli())
metrics2.unavailable.Store(true)
suite.balancer.metricsMap.Insert(2, metrics2)
suite.balancer.knownNodeInfos.Insert(2, nodeInfo{})
suite.balancer.RegisterNodeInfo([]nodeInfo{
suite.balancer.knownNodeInfos.Insert(2, NodeInfo{})
suite.balancer.RegisterNodeInfo([]NodeInfo{
{
nodeID: 2,
NodeID: 2,
},
})
suite.Eventually(func() bool {
@ -363,9 +364,9 @@ func (suite *LookAsideBalancerSuite) TestGetClientFailed() {
metrics1.ts.Store(time.Now().UnixMilli())
metrics1.unavailable.Store(true)
suite.balancer.metricsMap.Insert(2, metrics1)
suite.balancer.RegisterNodeInfo([]nodeInfo{
suite.balancer.RegisterNodeInfo([]NodeInfo{
{
nodeID: 2,
NodeID: 2,
},
})
@ -398,9 +399,9 @@ func (suite *LookAsideBalancerSuite) TestNodeRecover() {
metrics1 := &CostMetrics{}
metrics1.ts.Store(time.Now().UnixMilli())
suite.balancer.metricsMap.Insert(3, metrics1)
suite.balancer.RegisterNodeInfo([]nodeInfo{
suite.balancer.RegisterNodeInfo([]NodeInfo{
{
nodeID: 3,
NodeID: 3,
},
})
suite.Eventually(func() bool {
@ -415,8 +416,9 @@ func (suite *LookAsideBalancerSuite) TestNodeRecover() {
}
func (suite *LookAsideBalancerSuite) TestNodeOffline() {
Params.Save(Params.CommonCfg.SessionTTL.Key, "10")
Params.Save(Params.ProxyCfg.HealthCheckTimeout.Key, "1000")
params := paramtable.Get()
params.Save(params.CommonCfg.SessionTTL.Key, "10")
params.Save(params.ProxyCfg.HealthCheckTimeout.Key, "1000")
// mock qn down for a while and then recover
qn3 := mocks.NewMockQueryNodeClient(suite.T())
suite.clientMgr.ExpectedCalls = nil
@ -430,9 +432,9 @@ func (suite *LookAsideBalancerSuite) TestNodeOffline() {
metrics1 := &CostMetrics{}
metrics1.ts.Store(time.Now().UnixMilli())
suite.balancer.metricsMap.Insert(3, metrics1)
suite.balancer.RegisterNodeInfo([]nodeInfo{
suite.balancer.RegisterNodeInfo([]NodeInfo{
{
nodeID: 3,
NodeID: 3,
},
})
suite.Eventually(func() bool {

View File

@ -0,0 +1,337 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shardclient
import (
"context"
"fmt"
"strings"
"sync"
"time"
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/registry"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
type ShardClientMgr interface {
GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]NodeInfo, error)
GetShardLeaderList(ctx context.Context, database, collectionName string, collectionID int64, withCache bool) ([]string, error)
DeprecateShardCache(database, collectionName string)
InvalidateShardLeaderCache(collections []int64)
ListShardLocation() map[int64]NodeInfo
RemoveDatabase(database string)
GetClient(ctx context.Context, nodeInfo NodeInfo) (types.QueryNodeClient, error)
SetClientCreatorFunc(creator queryNodeCreatorFunc)
Start()
Close()
}
type shardClientMgrImpl struct {
clients *typeutil.ConcurrentMap[UniqueID, *shardClient]
clientCreator queryNodeCreatorFunc
closeCh chan struct{}
purgeInterval time.Duration
expiredDuration time.Duration
mixCoord types.MixCoordClient
leaderMut sync.RWMutex
collLeader map[string]map[string]*shardLeaders // database -> collectionName -> collection_leaders
}
const (
defaultPurgeInterval = 600 * time.Second
defaultExpiredDuration = 60 * time.Minute
)
// shardClientMgrOpt provides a way to set params in ShardClientMgr
type shardClientMgrOpt func(s ShardClientMgr)
func withShardClientCreator(creator queryNodeCreatorFunc) shardClientMgrOpt {
return func(s ShardClientMgr) { s.SetClientCreatorFunc(creator) }
}
func DefaultQueryNodeClientCreator(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
return registry.GetInMemoryResolver().ResolveQueryNode(ctx, addr, nodeID)
}
// NewShardClientMgr creates a new ShardClientMgr backed by the given mix coordinator client
func NewShardClientMgr(mixCoord types.MixCoordClient, options ...shardClientMgrOpt) *shardClientMgrImpl {
s := &shardClientMgrImpl{
clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](),
clientCreator: DefaultQueryNodeClientCreator,
closeCh: make(chan struct{}),
purgeInterval: defaultPurgeInterval,
expiredDuration: defaultExpiredDuration,
collLeader: make(map[string]map[string]*shardLeaders),
mixCoord: mixCoord,
}
for _, opt := range options {
opt(s)
}
return s
}
func (c *shardClientMgrImpl) SetClientCreatorFunc(creator queryNodeCreatorFunc) {
c.clientCreator = creator
}
func (m *shardClientMgrImpl) GetShard(ctx context.Context, withCache bool, database, collectionName string, collectionID int64, channel string) ([]NodeInfo, error) {
method := "GetShard"
// check cache first
cacheShardLeaders := m.getCachedShardLeaders(database, collectionName, method)
if cacheShardLeaders == nil || !withCache {
// refresh shard leader cache
newShardLeaders, err := m.updateShardLocationCache(ctx, database, collectionName, collectionID)
if err != nil {
return nil, err
}
cacheShardLeaders = newShardLeaders
}
return cacheShardLeaders.Get(channel), nil
}
func (m *shardClientMgrImpl) GetShardLeaderList(ctx context.Context, database, collectionName string, collectionID int64, withCache bool) ([]string, error) {
method := "GetShardLeaderList"
// check cache first
cacheShardLeaders := m.getCachedShardLeaders(database, collectionName, method)
if cacheShardLeaders == nil || !withCache {
// refresh shard leader cache
newShardLeaders, err := m.updateShardLocationCache(ctx, database, collectionName, collectionID)
if err != nil {
return nil, err
}
cacheShardLeaders = newShardLeaders
}
return cacheShardLeaders.GetShardLeaderList(), nil
}
func (m *shardClientMgrImpl) getCachedShardLeaders(database, collectionName, caller string) *shardLeaders {
m.leaderMut.RLock()
var cacheShardLeaders *shardLeaders
db, ok := m.collLeader[database]
if !ok {
cacheShardLeaders = nil
} else {
cacheShardLeaders = db[collectionName]
}
m.leaderMut.RUnlock()
if cacheShardLeaders != nil {
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), caller, metrics.CacheHitLabel).Inc()
} else {
metrics.ProxyCacheStatsCounter.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), caller, metrics.CacheMissLabel).Inc()
}
return cacheShardLeaders
}
func (m *shardClientMgrImpl) updateShardLocationCache(ctx context.Context, database, collectionName string, collectionID int64) (*shardLeaders, error) {
log := log.Ctx(ctx).With(
zap.String("db", database),
zap.String("collectionName", collectionName),
zap.Int64("collectionID", collectionID))
method := "updateShardLocationCache"
tr := timerecord.NewTimeRecorder(method)
defer metrics.ProxyUpdateCacheLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), method).
Observe(float64(tr.ElapseSpan().Milliseconds()))
req := &querypb.GetShardLeadersRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_GetShardLeaders),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
CollectionID: collectionID,
WithUnserviceableShards: true,
}
resp, err := m.mixCoord.GetShardLeaders(ctx, req)
if err := merr.CheckRPCCall(resp.GetStatus(), err); err != nil {
log.Error("failed to get shard locations",
zap.Int64("collectionID", collectionID),
zap.Error(err))
return nil, err
}
shards := parseShardLeaderList2QueryNode(resp.GetShards())
// convert shards map to string for logging
if log.Logger.Level() == zap.DebugLevel {
shardStr := make([]string, 0, len(shards))
for channel, nodes := range shards {
nodeStrs := make([]string, 0, len(nodes))
for _, node := range nodes {
nodeStrs = append(nodeStrs, node.String())
}
shardStr = append(shardStr, fmt.Sprintf("%s:[%s]", channel, strings.Join(nodeStrs, ", ")))
}
log.Debug("update shard leader cache", zap.String("newShardLeaders", strings.Join(shardStr, ", ")))
}
newShardLeaders := &shardLeaders{
collectionID: collectionID,
shardLeaders: shards,
idx: atomic.NewInt64(0),
}
m.leaderMut.Lock()
if _, ok := m.collLeader[database]; !ok {
m.collLeader[database] = make(map[string]*shardLeaders)
}
m.collLeader[database][collectionName] = newShardLeaders
m.leaderMut.Unlock()
return newShardLeaders, nil
}
func parseShardLeaderList2QueryNode(shardsLeaders []*querypb.ShardLeadersList) map[string][]NodeInfo {
shard2QueryNodes := make(map[string][]NodeInfo)
for _, leaders := range shardsLeaders {
qns := make([]NodeInfo, len(leaders.GetNodeIds()))
for j := range qns {
qns[j] = NodeInfo{leaders.GetNodeIds()[j], leaders.GetNodeAddrs()[j], leaders.GetServiceable()[j]}
}
shard2QueryNodes[leaders.GetChannelName()] = qns
}
return shard2QueryNodes
}
// ListShardLocation returns all cached shard leader locations; used for garbage collecting shard clients
func (m *shardClientMgrImpl) ListShardLocation() map[int64]NodeInfo {
m.leaderMut.RLock()
defer m.leaderMut.RUnlock()
shardLeaderInfo := make(map[int64]NodeInfo)
for _, dbInfo := range m.collLeader {
for _, shardLeaders := range dbInfo {
for _, nodeInfos := range shardLeaders.shardLeaders {
for _, node := range nodeInfos {
shardLeaderInfo[node.NodeID] = node
}
}
}
}
return shardLeaderInfo
}
func (m *shardClientMgrImpl) RemoveDatabase(database string) {
m.leaderMut.Lock()
defer m.leaderMut.Unlock()
delete(m.collLeader, database)
}
// DeprecateShardCache clears the shard leader cache of a collection
func (m *shardClientMgrImpl) DeprecateShardCache(database, collectionName string) {
log.Info("deprecate shard cache for collection", zap.String("collectionName", collectionName))
m.leaderMut.Lock()
defer m.leaderMut.Unlock()
dbInfo, ok := m.collLeader[database]
if ok {
delete(dbInfo, collectionName)
if len(dbInfo) == 0 {
delete(m.collLeader, database)
}
}
}
// InvalidateShardLeaderCache is called when a shard leader balance has happened
func (m *shardClientMgrImpl) InvalidateShardLeaderCache(collections []int64) {
log.Info("Invalidate shard cache for collections", zap.Int64s("collectionIDs", collections))
m.leaderMut.Lock()
defer m.leaderMut.Unlock()
collectionSet := typeutil.NewUniqueSet(collections...)
for dbName, dbInfo := range m.collLeader {
for collectionName, shardLeaders := range dbInfo {
if collectionSet.Contain(shardLeaders.collectionID) {
delete(dbInfo, collectionName)
}
}
if len(dbInfo) == 0 {
delete(m.collLeader, dbName)
}
}
}
func (c *shardClientMgrImpl) GetClient(ctx context.Context, info NodeInfo) (types.QueryNodeClient, error) {
client, _ := c.clients.GetOrInsert(info.NodeID, newShardClient(info, c.clientCreator, c.expiredDuration))
return client.getClient(ctx)
}
// PurgeClient purges client if it is not used for a long time
func (c *shardClientMgrImpl) PurgeClient() {
ticker := time.NewTicker(c.purgeInterval)
defer ticker.Stop()
for {
select {
case <-c.closeCh:
return
case <-ticker.C:
shardLocations := c.ListShardLocation()
c.clients.Range(func(key UniqueID, value *shardClient) bool {
if _, ok := shardLocations[key]; !ok {
// if the client has been idle longer than expiredDuration and is no longer a delegator, remove it
if value.isExpired() {
closed := value.Close(false)
if closed {
c.clients.Remove(key)
log.Info("remove idle node client", zap.Int64("nodeID", key))
}
}
}
return true
})
}
}
}
func (c *shardClientMgrImpl) Start() {
go c.PurgeClient()
}
// Close release clients
func (c *shardClientMgrImpl) Close() {
close(c.closeCh)
c.clients.Range(func(key UniqueID, value *shardClient) bool {
value.Close(true)
c.clients.Remove(key)
return true
})
}
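
A hedged sketch of the cache-aside flow that GetShard and GetClient implement; the database, collection, channel, and IDs below are placeholders rather than values from this change.

// Sketch only; not part of this diff.
package example

import (
    "context"

    "github.com/milvus-io/milvus/internal/proxy/shardclient"
    "github.com/milvus-io/milvus/internal/types"
)

func resolveShardLeaders(ctx context.Context, mixCoord types.MixCoordClient) error {
    mgr := shardclient.NewShardClientMgr(mixCoord)
    mgr.Start() // starts the idle-client purge loop
    defer mgr.Close()

    // withCache=true serves from the database/collection cache when present;
    // otherwise shard leaders are fetched from mixCoord and the cache is refreshed.
    leaders, err := mgr.GetShard(ctx, true, "default", "demo_collection", 100, "demo-channel-0")
    if err != nil {
        return err
    }

    for _, node := range leaders {
        // GetClient lazily creates (and reuses) a QueryNode client per NodeInfo.
        if _, err := mgr.GetClient(ctx, node); err != nil {
            return err
        }
    }

    // After a shard-leader balance for collection 100, the cached entry is dropped
    // and repopulated on the next lookup.
    mgr.InvalidateShardLeaderCache([]int64{100})
    return nil
}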

View File

@ -1,6 +1,6 @@
// Code generated by mockery v2.53.3. DO NOT EDIT.
package proxy
package shardclient
import (
context "context"
@ -89,7 +89,7 @@ func (_c *MockLBBalancer_Close_Call) RunAndReturn(run func()) *MockLBBalancer_Cl
}
// RegisterNodeInfo provides a mock function with given fields: nodeInfos
func (_m *MockLBBalancer) RegisterNodeInfo(nodeInfos []nodeInfo) {
func (_m *MockLBBalancer) RegisterNodeInfo(nodeInfos []NodeInfo) {
_m.Called(nodeInfos)
}
@ -99,14 +99,14 @@ type MockLBBalancer_RegisterNodeInfo_Call struct {
}
// RegisterNodeInfo is a helper method to define mock.On call
// - nodeInfos []nodeInfo
// - nodeInfos []NodeInfo
func (_e *MockLBBalancer_Expecter) RegisterNodeInfo(nodeInfos interface{}) *MockLBBalancer_RegisterNodeInfo_Call {
return &MockLBBalancer_RegisterNodeInfo_Call{Call: _e.mock.On("RegisterNodeInfo", nodeInfos)}
}
func (_c *MockLBBalancer_RegisterNodeInfo_Call) Run(run func(nodeInfos []nodeInfo)) *MockLBBalancer_RegisterNodeInfo_Call {
func (_c *MockLBBalancer_RegisterNodeInfo_Call) Run(run func(nodeInfos []NodeInfo)) *MockLBBalancer_RegisterNodeInfo_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]nodeInfo))
run(args[0].([]NodeInfo))
})
return _c
}
@ -116,7 +116,7 @@ func (_c *MockLBBalancer_RegisterNodeInfo_Call) Return() *MockLBBalancer_Registe
return _c
}
func (_c *MockLBBalancer_RegisterNodeInfo_Call) RunAndReturn(run func([]nodeInfo)) *MockLBBalancer_RegisterNodeInfo_Call {
func (_c *MockLBBalancer_RegisterNodeInfo_Call) RunAndReturn(run func([]NodeInfo)) *MockLBBalancer_RegisterNodeInfo_Call {
_c.Run(run)
return _c
}

View File

@ -1,6 +1,6 @@
// Code generated by mockery v2.53.3. DO NOT EDIT.
package proxy
package shardclient
import (
context "context"

View File

@ -0,0 +1,465 @@
// Code generated by mockery v2.53.3. DO NOT EDIT.
package shardclient
import (
context "context"
types "github.com/milvus-io/milvus/internal/types"
mock "github.com/stretchr/testify/mock"
)
// MockShardClientManager is an autogenerated mock type for the ShardClientMgr type
type MockShardClientManager struct {
mock.Mock
}
type MockShardClientManager_Expecter struct {
mock *mock.Mock
}
func (_m *MockShardClientManager) EXPECT() *MockShardClientManager_Expecter {
return &MockShardClientManager_Expecter{mock: &_m.Mock}
}
// Close provides a mock function with no fields
func (_m *MockShardClientManager) Close() {
_m.Called()
}
// MockShardClientManager_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close'
type MockShardClientManager_Close_Call struct {
*mock.Call
}
// Close is a helper method to define mock.On call
func (_e *MockShardClientManager_Expecter) Close() *MockShardClientManager_Close_Call {
return &MockShardClientManager_Close_Call{Call: _e.mock.On("Close")}
}
func (_c *MockShardClientManager_Close_Call) Run(run func()) *MockShardClientManager_Close_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockShardClientManager_Close_Call) Return() *MockShardClientManager_Close_Call {
_c.Call.Return()
return _c
}
func (_c *MockShardClientManager_Close_Call) RunAndReturn(run func()) *MockShardClientManager_Close_Call {
_c.Run(run)
return _c
}
// DeprecateShardCache provides a mock function with given fields: database, collectionName
func (_m *MockShardClientManager) DeprecateShardCache(database string, collectionName string) {
_m.Called(database, collectionName)
}
// MockShardClientManager_DeprecateShardCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeprecateShardCache'
type MockShardClientManager_DeprecateShardCache_Call struct {
*mock.Call
}
// DeprecateShardCache is a helper method to define mock.On call
// - database string
// - collectionName string
func (_e *MockShardClientManager_Expecter) DeprecateShardCache(database interface{}, collectionName interface{}) *MockShardClientManager_DeprecateShardCache_Call {
return &MockShardClientManager_DeprecateShardCache_Call{Call: _e.mock.On("DeprecateShardCache", database, collectionName)}
}
func (_c *MockShardClientManager_DeprecateShardCache_Call) Run(run func(database string, collectionName string)) *MockShardClientManager_DeprecateShardCache_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string), args[1].(string))
})
return _c
}
func (_c *MockShardClientManager_DeprecateShardCache_Call) Return() *MockShardClientManager_DeprecateShardCache_Call {
_c.Call.Return()
return _c
}
func (_c *MockShardClientManager_DeprecateShardCache_Call) RunAndReturn(run func(string, string)) *MockShardClientManager_DeprecateShardCache_Call {
_c.Run(run)
return _c
}
// GetClient provides a mock function with given fields: ctx, nodeInfo
func (_m *MockShardClientManager) GetClient(ctx context.Context, nodeInfo NodeInfo) (types.QueryNodeClient, error) {
ret := _m.Called(ctx, nodeInfo)
if len(ret) == 0 {
panic("no return value specified for GetClient")
}
var r0 types.QueryNodeClient
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, NodeInfo) (types.QueryNodeClient, error)); ok {
return rf(ctx, nodeInfo)
}
if rf, ok := ret.Get(0).(func(context.Context, NodeInfo) types.QueryNodeClient); ok {
r0 = rf(ctx, nodeInfo)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(types.QueryNodeClient)
}
}
if rf, ok := ret.Get(1).(func(context.Context, NodeInfo) error); ok {
r1 = rf(ctx, nodeInfo)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockShardClientManager_GetClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetClient'
type MockShardClientManager_GetClient_Call struct {
*mock.Call
}
// GetClient is a helper method to define mock.On call
// - ctx context.Context
// - nodeInfo NodeInfo
func (_e *MockShardClientManager_Expecter) GetClient(ctx interface{}, nodeInfo interface{}) *MockShardClientManager_GetClient_Call {
return &MockShardClientManager_GetClient_Call{Call: _e.mock.On("GetClient", ctx, nodeInfo)}
}
func (_c *MockShardClientManager_GetClient_Call) Run(run func(ctx context.Context, nodeInfo NodeInfo)) *MockShardClientManager_GetClient_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(NodeInfo))
})
return _c
}
func (_c *MockShardClientManager_GetClient_Call) Return(_a0 types.QueryNodeClient, _a1 error) *MockShardClientManager_GetClient_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockShardClientManager_GetClient_Call) RunAndReturn(run func(context.Context, NodeInfo) (types.QueryNodeClient, error)) *MockShardClientManager_GetClient_Call {
_c.Call.Return(run)
return _c
}
// GetShard provides a mock function with given fields: ctx, withCache, database, collectionName, collectionID, channel
func (_m *MockShardClientManager) GetShard(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64, channel string) ([]NodeInfo, error) {
ret := _m.Called(ctx, withCache, database, collectionName, collectionID, channel)
if len(ret) == 0 {
panic("no return value specified for GetShard")
}
var r0 []NodeInfo
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64, string) ([]NodeInfo, error)); ok {
return rf(ctx, withCache, database, collectionName, collectionID, channel)
}
if rf, ok := ret.Get(0).(func(context.Context, bool, string, string, int64, string) []NodeInfo); ok {
r0 = rf(ctx, withCache, database, collectionName, collectionID, channel)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]NodeInfo)
}
}
if rf, ok := ret.Get(1).(func(context.Context, bool, string, string, int64, string) error); ok {
r1 = rf(ctx, withCache, database, collectionName, collectionID, channel)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockShardClientManager_GetShard_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetShard'
type MockShardClientManager_GetShard_Call struct {
*mock.Call
}
// GetShard is a helper method to define mock.On call
// - ctx context.Context
// - withCache bool
// - database string
// - collectionName string
// - collectionID int64
// - channel string
func (_e *MockShardClientManager_Expecter) GetShard(ctx interface{}, withCache interface{}, database interface{}, collectionName interface{}, collectionID interface{}, channel interface{}) *MockShardClientManager_GetShard_Call {
return &MockShardClientManager_GetShard_Call{Call: _e.mock.On("GetShard", ctx, withCache, database, collectionName, collectionID, channel)}
}
func (_c *MockShardClientManager_GetShard_Call) Run(run func(ctx context.Context, withCache bool, database string, collectionName string, collectionID int64, channel string)) *MockShardClientManager_GetShard_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(bool), args[2].(string), args[3].(string), args[4].(int64), args[5].(string))
})
return _c
}
func (_c *MockShardClientManager_GetShard_Call) Return(_a0 []NodeInfo, _a1 error) *MockShardClientManager_GetShard_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockShardClientManager_GetShard_Call) RunAndReturn(run func(context.Context, bool, string, string, int64, string) ([]NodeInfo, error)) *MockShardClientManager_GetShard_Call {
_c.Call.Return(run)
return _c
}
// GetShardLeaderList provides a mock function with given fields: ctx, database, collectionName, collectionID, withCache
func (_m *MockShardClientManager) GetShardLeaderList(ctx context.Context, database string, collectionName string, collectionID int64, withCache bool) ([]string, error) {
ret := _m.Called(ctx, database, collectionName, collectionID, withCache)
if len(ret) == 0 {
panic("no return value specified for GetShardLeaderList")
}
var r0 []string
var r1 error
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, bool) ([]string, error)); ok {
return rf(ctx, database, collectionName, collectionID, withCache)
}
if rf, ok := ret.Get(0).(func(context.Context, string, string, int64, bool) []string); ok {
r0 = rf(ctx, database, collectionName, collectionID, withCache)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
}
}
if rf, ok := ret.Get(1).(func(context.Context, string, string, int64, bool) error); ok {
r1 = rf(ctx, database, collectionName, collectionID, withCache)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// MockShardClientManager_GetShardLeaderList_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetShardLeaderList'
type MockShardClientManager_GetShardLeaderList_Call struct {
*mock.Call
}
// GetShardLeaderList is a helper method to define mock.On call
// - ctx context.Context
// - database string
// - collectionName string
// - collectionID int64
// - withCache bool
func (_e *MockShardClientManager_Expecter) GetShardLeaderList(ctx interface{}, database interface{}, collectionName interface{}, collectionID interface{}, withCache interface{}) *MockShardClientManager_GetShardLeaderList_Call {
return &MockShardClientManager_GetShardLeaderList_Call{Call: _e.mock.On("GetShardLeaderList", ctx, database, collectionName, collectionID, withCache)}
}
func (_c *MockShardClientManager_GetShardLeaderList_Call) Run(run func(ctx context.Context, database string, collectionName string, collectionID int64, withCache bool)) *MockShardClientManager_GetShardLeaderList_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(int64), args[4].(bool))
})
return _c
}
func (_c *MockShardClientManager_GetShardLeaderList_Call) Return(_a0 []string, _a1 error) *MockShardClientManager_GetShardLeaderList_Call {
_c.Call.Return(_a0, _a1)
return _c
}
func (_c *MockShardClientManager_GetShardLeaderList_Call) RunAndReturn(run func(context.Context, string, string, int64, bool) ([]string, error)) *MockShardClientManager_GetShardLeaderList_Call {
_c.Call.Return(run)
return _c
}
// InvalidateShardLeaderCache provides a mock function with given fields: collections
func (_m *MockShardClientManager) InvalidateShardLeaderCache(collections []int64) {
_m.Called(collections)
}
// MockShardClientManager_InvalidateShardLeaderCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateShardLeaderCache'
type MockShardClientManager_InvalidateShardLeaderCache_Call struct {
*mock.Call
}
// InvalidateShardLeaderCache is a helper method to define mock.On call
// - collections []int64
func (_e *MockShardClientManager_Expecter) InvalidateShardLeaderCache(collections interface{}) *MockShardClientManager_InvalidateShardLeaderCache_Call {
return &MockShardClientManager_InvalidateShardLeaderCache_Call{Call: _e.mock.On("InvalidateShardLeaderCache", collections)}
}
func (_c *MockShardClientManager_InvalidateShardLeaderCache_Call) Run(run func(collections []int64)) *MockShardClientManager_InvalidateShardLeaderCache_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].([]int64))
})
return _c
}
func (_c *MockShardClientManager_InvalidateShardLeaderCache_Call) Return() *MockShardClientManager_InvalidateShardLeaderCache_Call {
_c.Call.Return()
return _c
}
func (_c *MockShardClientManager_InvalidateShardLeaderCache_Call) RunAndReturn(run func([]int64)) *MockShardClientManager_InvalidateShardLeaderCache_Call {
_c.Run(run)
return _c
}
// ListShardLocation provides a mock function with no fields
func (_m *MockShardClientManager) ListShardLocation() map[int64]NodeInfo {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for ListShardLocation")
}
var r0 map[int64]NodeInfo
if rf, ok := ret.Get(0).(func() map[int64]NodeInfo); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[int64]NodeInfo)
}
}
return r0
}
// MockShardClientManager_ListShardLocation_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListShardLocation'
type MockShardClientManager_ListShardLocation_Call struct {
*mock.Call
}
// ListShardLocation is a helper method to define mock.On call
func (_e *MockShardClientManager_Expecter) ListShardLocation() *MockShardClientManager_ListShardLocation_Call {
return &MockShardClientManager_ListShardLocation_Call{Call: _e.mock.On("ListShardLocation")}
}
func (_c *MockShardClientManager_ListShardLocation_Call) Run(run func()) *MockShardClientManager_ListShardLocation_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockShardClientManager_ListShardLocation_Call) Return(_a0 map[int64]NodeInfo) *MockShardClientManager_ListShardLocation_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockShardClientManager_ListShardLocation_Call) RunAndReturn(run func() map[int64]NodeInfo) *MockShardClientManager_ListShardLocation_Call {
_c.Call.Return(run)
return _c
}
// RemoveDatabase provides a mock function with given fields: database
func (_m *MockShardClientManager) RemoveDatabase(database string) {
_m.Called(database)
}
// MockShardClientManager_RemoveDatabase_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveDatabase'
type MockShardClientManager_RemoveDatabase_Call struct {
*mock.Call
}
// RemoveDatabase is a helper method to define mock.On call
// - database string
func (_e *MockShardClientManager_Expecter) RemoveDatabase(database interface{}) *MockShardClientManager_RemoveDatabase_Call {
return &MockShardClientManager_RemoveDatabase_Call{Call: _e.mock.On("RemoveDatabase", database)}
}
func (_c *MockShardClientManager_RemoveDatabase_Call) Run(run func(database string)) *MockShardClientManager_RemoveDatabase_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
})
return _c
}
func (_c *MockShardClientManager_RemoveDatabase_Call) Return() *MockShardClientManager_RemoveDatabase_Call {
_c.Call.Return()
return _c
}
func (_c *MockShardClientManager_RemoveDatabase_Call) RunAndReturn(run func(string)) *MockShardClientManager_RemoveDatabase_Call {
_c.Run(run)
return _c
}
// SetClientCreatorFunc provides a mock function with given fields: creator
func (_m *MockShardClientManager) SetClientCreatorFunc(creator queryNodeCreatorFunc) {
_m.Called(creator)
}
// MockShardClientManager_SetClientCreatorFunc_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetClientCreatorFunc'
type MockShardClientManager_SetClientCreatorFunc_Call struct {
*mock.Call
}
// SetClientCreatorFunc is a helper method to define mock.On call
// - creator queryNodeCreatorFunc
func (_e *MockShardClientManager_Expecter) SetClientCreatorFunc(creator interface{}) *MockShardClientManager_SetClientCreatorFunc_Call {
return &MockShardClientManager_SetClientCreatorFunc_Call{Call: _e.mock.On("SetClientCreatorFunc", creator)}
}
func (_c *MockShardClientManager_SetClientCreatorFunc_Call) Run(run func(creator queryNodeCreatorFunc)) *MockShardClientManager_SetClientCreatorFunc_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(queryNodeCreatorFunc))
})
return _c
}
func (_c *MockShardClientManager_SetClientCreatorFunc_Call) Return() *MockShardClientManager_SetClientCreatorFunc_Call {
_c.Call.Return()
return _c
}
func (_c *MockShardClientManager_SetClientCreatorFunc_Call) RunAndReturn(run func(queryNodeCreatorFunc)) *MockShardClientManager_SetClientCreatorFunc_Call {
_c.Run(run)
return _c
}
// Start provides a mock function with no fields
func (_m *MockShardClientManager) Start() {
_m.Called()
}
// MockShardClientManager_Start_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Start'
type MockShardClientManager_Start_Call struct {
*mock.Call
}
// Start is a helper method to define mock.On call
func (_e *MockShardClientManager_Expecter) Start() *MockShardClientManager_Start_Call {
return &MockShardClientManager_Start_Call{Call: _e.mock.On("Start")}
}
func (_c *MockShardClientManager_Start_Call) Run(run func()) *MockShardClientManager_Start_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockShardClientManager_Start_Call) Return() *MockShardClientManager_Start_Call {
_c.Call.Return()
return _c
}
func (_c *MockShardClientManager_Start_Call) RunAndReturn(run func()) *MockShardClientManager_Start_Call {
_c.Run(run)
return _c
}
// NewMockShardClientManager creates a new instance of MockShardClientManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockShardClientManager(t interface {
mock.TestingT
Cleanup(func())
}) *MockShardClientManager {
mock := &MockShardClientManager{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
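
A hedged example of how a proxy test might consume the regenerated mock; the expectation values are illustrative and not taken from this change.

// Sketch only; not part of this diff.
package example

import (
    "context"
    "testing"

    "github.com/stretchr/testify/mock"

    "github.com/milvus-io/milvus/internal/proxy/shardclient"
)

func TestWithMockShardClientManager(t *testing.T) {
    mgr := shardclient.NewMockShardClientManager(t)

    // Expect exactly one shard lookup and return a canned leader list.
    mgr.EXPECT().
        GetShard(mock.Anything, mock.Anything, "default", "demo", int64(100), "ch-0").
        Return([]shardclient.NodeInfo{{NodeID: 1, Address: "qn-1:21123", Serviceable: true}}, nil).
        Once()

    leaders, err := mgr.GetShard(context.Background(), true, "default", "demo", 100, "ch-0")
    if err != nil || len(leaders) != 1 {
        t.Fatalf("unexpected result: %v, %v", leaders, err)
    }
}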

View File

@ -0,0 +1,76 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shardclient
import (
"math/rand"
"github.com/samber/lo"
"go.uber.org/atomic"
)
// shardLeaders wraps shard leader mapping for iteration.
type shardLeaders struct {
idx *atomic.Int64
collectionID int64
shardLeaders map[string][]NodeInfo
}
func (sl *shardLeaders) Get(channel string) []NodeInfo {
return sl.shardLeaders[channel]
}
func (sl *shardLeaders) GetShardLeaderList() []string {
return lo.Keys(sl.shardLeaders)
}
type shardLeadersReader struct {
leaders *shardLeaders
idx int64
}
// Shuffle returns the shard leaders with each channel's replica list shuffled.
func (it shardLeadersReader) Shuffle() map[string][]NodeInfo {
result := make(map[string][]NodeInfo)
for channel, leaders := range it.leaders.shardLeaders {
l := len(leaders)
// shuffle all replicas into random order
shuffled := make([]NodeInfo, l)
for i, randIndex := range rand.Perm(l) {
shuffled[i] = leaders[randIndex]
}
// rotate the idx-th leader to the front so each replica takes a fair turn as the first pick
for index, leader := range shuffled {
if leader == leaders[int(it.idx)%l] {
shuffled[0], shuffled[index] = shuffled[index], shuffled[0]
}
}
result[channel] = shuffled
}
return result
}
// GetReader returns a shuffle reader for the shard leaders.
func (sl *shardLeaders) GetReader() shardLeadersReader {
idx := sl.idx.Inc()
return shardLeadersReader{
leaders: sl,
idx: idx,
}
}
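
Since shardLeaders is unexported, the following sketch would have to live inside the shardclient package (for example in a test); it illustrates the rotation guarantee of GetReader().Shuffle() for three leaders on one channel, with placeholder node IDs.

// Sketch only; not part of this diff.
package shardclient

import (
    "fmt"

    "go.uber.org/atomic"
)

func demoShuffleRotation() {
    sl := &shardLeaders{
        idx:          atomic.NewInt64(0),
        collectionID: 100,
        shardLeaders: map[string][]NodeInfo{
            "ch-0": {{NodeID: 1}, {NodeID: 2}, {NodeID: 3}},
        },
    }

    // GetReader increments idx before each read, so idx takes the values 1, 2, 3.
    // Shuffle promotes leaders[idx%3] to position 0, so the first replica cycles
    // through node IDs 2, 3, 1 while the remaining order stays random.
    for i := 0; i < 3; i++ {
        shuffled := sl.GetReader().Shuffle()
        fmt.Println(shuffled["ch-0"][0].NodeID)
    }
}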

View File

@ -13,7 +13,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
package shardclient
import (
"context"
@ -32,7 +33,7 @@ func NewRoundRobinBalancer() *RoundRobinBalancer {
return &RoundRobinBalancer{}
}
func (b *RoundRobinBalancer) RegisterNodeInfo(nodeInfos []nodeInfo) {}
func (b *RoundRobinBalancer) RegisterNodeInfo(nodeInfos []NodeInfo) {}
func (b *RoundRobinBalancer) SelectNode(ctx context.Context, availableNodes []int64, cost int64) (int64, error) {
if len(availableNodes) == 0 {

View File

@ -13,7 +13,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
package shardclient
import (
"context"

View File

@ -1,4 +1,20 @@
package proxy
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shardclient
import (
"context"
@ -10,30 +26,31 @@ import (
"go.uber.org/atomic"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/registry"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
type UniqueID = typeutil.UniqueID
type queryNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error)
type nodeInfo struct {
nodeID UniqueID
address string
serviceable bool
type NodeInfo struct {
NodeID UniqueID
Address string
Serviceable bool
}
func (n nodeInfo) String() string {
return fmt.Sprintf("<NodeID: %d, serviceable: %v, address: %s>", n.nodeID, n.serviceable, n.address)
func (n NodeInfo) String() string {
return fmt.Sprintf("<NodeID: %d, serviceable: %v, address: %s>", n.NodeID, n.Serviceable, n.Address)
}
var errClosed = errors.New("client is closed")
type shardClient struct {
sync.RWMutex
info nodeInfo
info NodeInfo
poolSize int
clients []types.QueryNodeClient
creator queryNodeCreatorFunc
@ -46,7 +63,7 @@ type shardClient struct {
expiredDuration time.Duration
}
func newShardClient(info nodeInfo, creator queryNodeCreatorFunc, expiredDuration time.Duration) *shardClient {
func newShardClient(info NodeInfo, creator queryNodeCreatorFunc, expiredDuration time.Duration) *shardClient {
return &shardClient{
info: info,
creator: creator,
@ -89,14 +106,14 @@ func (n *shardClient) initClients(ctx context.Context) error {
clients := make([]types.QueryNodeClient, 0, poolSize)
for i := 0; i < poolSize; i++ {
client, err := n.creator(ctx, n.info.address, n.info.nodeID)
client, err := n.creator(ctx, n.info.Address, n.info.NodeID)
if err != nil {
// Roll back already created clients
for _, c := range clients {
c.Close()
}
log.Info("failed to create client for node", zap.Int64("nodeID", n.info.nodeID), zap.Error(err))
return errors.Wrap(err, fmt.Sprintf("create client for node=%d failed", n.info.nodeID))
log.Info("failed to create client for node", zap.Int64("nodeID", n.info.NodeID), zap.Error(err))
return errors.Wrap(err, fmt.Sprintf("create client for node=%d failed", n.info.NodeID))
}
clients = append(clients, client)
}
@ -150,104 +167,3 @@ func (n *shardClient) close() {
}
n.clients = nil
}
// roundRobinSelectClient selects a client in a round-robin manner
type shardClientMgr interface {
GetClient(ctx context.Context, nodeInfo nodeInfo) (types.QueryNodeClient, error)
SetClientCreatorFunc(creator queryNodeCreatorFunc)
Start()
Close()
}
type shardClientMgrImpl struct {
clients *typeutil.ConcurrentMap[UniqueID, *shardClient]
clientCreator queryNodeCreatorFunc
closeCh chan struct{}
purgeInterval time.Duration
expiredDuration time.Duration
}
const (
defaultPurgeInterval = 600 * time.Second
defaultExpiredDuration = 60 * time.Minute
)
// SessionOpt provides a way to set params in SessionManager
type shardClientMgrOpt func(s shardClientMgr)
func withShardClientCreator(creator queryNodeCreatorFunc) shardClientMgrOpt {
return func(s shardClientMgr) { s.SetClientCreatorFunc(creator) }
}
func defaultQueryNodeClientCreator(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
return registry.GetInMemoryResolver().ResolveQueryNode(ctx, addr, nodeID)
}
// NewShardClientMgr creates a new shardClientMgr
func newShardClientMgr(options ...shardClientMgrOpt) *shardClientMgrImpl {
s := &shardClientMgrImpl{
clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](),
clientCreator: defaultQueryNodeClientCreator,
closeCh: make(chan struct{}),
purgeInterval: defaultPurgeInterval,
expiredDuration: defaultExpiredDuration,
}
for _, opt := range options {
opt(s)
}
return s
}
func (c *shardClientMgrImpl) SetClientCreatorFunc(creator queryNodeCreatorFunc) {
c.clientCreator = creator
}
func (c *shardClientMgrImpl) GetClient(ctx context.Context, info nodeInfo) (types.QueryNodeClient, error) {
client, _ := c.clients.GetOrInsert(info.nodeID, newShardClient(info, c.clientCreator, c.expiredDuration))
return client.getClient(ctx)
}
// PurgeClient purges client if it is not used for a long time
func (c *shardClientMgrImpl) PurgeClient() {
ticker := time.NewTicker(c.purgeInterval)
defer ticker.Stop()
for {
select {
case <-c.closeCh:
return
case <-ticker.C:
shardLocations := globalMetaCache.ListShardLocation()
c.clients.Range(func(key UniqueID, value *shardClient) bool {
if _, ok := shardLocations[key]; !ok {
// if the client is not used for more than 1 hour, and it's not a delegator anymore, should remove it
if value.isExpired() {
closed := value.Close(false)
if closed {
c.clients.Remove(key)
log.Info("remove idle node client", zap.Int64("nodeID", key))
}
}
}
return true
})
}
}
}
func (c *shardClientMgrImpl) Start() {
go c.PurgeClient()
}
// Close release clients
func (c *shardClientMgrImpl) Close() {
close(c.closeCh)
c.clients.Range(func(key UniqueID, value *shardClient) bool {
value.Close(true)
c.clients.Remove(key)
return true
})
}

View File

@ -0,0 +1,602 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shardclient
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/atomic"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
func TestShardClientMgr(t *testing.T) {
ctx := context.Background()
nodeInfo := NodeInfo{
NodeID: 1,
}
qn := mocks.NewMockQueryNodeClient(t)
qn.EXPECT().Close().Return(nil)
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
return qn, nil
}
mixcoord := mocks.NewMockMixCoordClient(t)
mgr := NewShardClientMgr(mixcoord)
mgr.SetClientCreatorFunc(creator)
_, err := mgr.GetClient(ctx, nodeInfo)
assert.Nil(t, err)
mgr.Close()
assert.Equal(t, mgr.clients.Len(), 0)
}
func TestShardClient(t *testing.T) {
nodeInfo := NodeInfo{
NodeID: 1,
}
qn := mocks.NewMockQueryNodeClient(t)
qn.EXPECT().Close().Return(nil).Maybe()
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
return qn, nil
}
shardClient := newShardClient(nodeInfo, creator, 3*time.Second)
assert.Equal(t, len(shardClient.clients), 0)
assert.Equal(t, false, shardClient.initialized.Load())
assert.Equal(t, false, shardClient.isClosed)
ctx := context.Background()
_, err := shardClient.getClient(ctx)
assert.Nil(t, err)
assert.Equal(t, len(shardClient.clients), paramtable.Get().ProxyCfg.QueryNodePoolingSize.GetAsInt())
// test close
closed := shardClient.Close(false)
assert.False(t, closed)
closed = shardClient.Close(true)
assert.True(t, closed)
}
func TestPurgeClient(t *testing.T) {
node := NodeInfo{
NodeID: 1,
}
returnEmptyResult := atomic.NewBool(false)
qn := mocks.NewMockQueryNodeClient(t)
qn.EXPECT().Close().Return(nil).Maybe()
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
return qn, nil
}
s := &shardClientMgrImpl{
clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](),
clientCreator: creator,
closeCh: make(chan struct{}),
purgeInterval: 1 * time.Second,
expiredDuration: 3 * time.Second,
collLeader: map[string]map[string]*shardLeaders{
"default": {
"test": {
idx: atomic.NewInt64(0),
collectionID: 1,
shardLeaders: map[string][]NodeInfo{
"0": {node},
},
},
},
},
}
go s.PurgeClient()
defer s.Close()
_, err := s.GetClient(context.Background(), node)
assert.Nil(t, err)
qnClient, ok := s.clients.Get(1)
assert.True(t, ok)
assert.True(t, qnClient.lastActiveTs.Load() > 0)
time.Sleep(2 * time.Second)
// the client should not be purged before expiredDuration
assert.Equal(t, s.clients.Len(), 1)
assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() >= 2*time.Second.Nanoseconds())
_, err = s.GetClient(context.Background(), node)
assert.Nil(t, err)
time.Sleep(2 * time.Second)
// GetClient should refresh lastActiveTs, expected client should not be purged
assert.Equal(t, s.clients.Len(), 1)
assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() < 3*time.Second.Nanoseconds())
time.Sleep(2 * time.Second)
// the client reaches expiredDuration but is still a shard leader, so it should not be purged
assert.Equal(t, s.clients.Len(), 1)
assert.True(t, time.Now().UnixNano()-qnClient.lastActiveTs.Load() > 3*time.Second.Nanoseconds())
s.DeprecateShardCache("default", "test")
returnEmptyResult.Store(true)
time.Sleep(2 * time.Second)
// remove client from shard location, expected client should be purged
assert.Eventually(t, func() bool {
return s.clients.Len() == 0
}, 10*time.Second, 1*time.Second)
}
func TestDeprecateShardCache(t *testing.T) {
node := NodeInfo{
NodeID: 1,
}
qn := mocks.NewMockQueryNodeClient(t)
qn.EXPECT().Close().Return(nil).Maybe()
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
return qn, nil
}
mixcoord := mocks.NewMockMixCoordClient(t)
mgr := NewShardClientMgr(mixcoord)
mgr.SetClientCreatorFunc(creator)
t.Run("Clear with no collection info", func(t *testing.T) {
mgr.DeprecateShardCache("default", "collection_not_exist")
// Should not panic or error
})
t.Run("Clear valid collection empty cache", func(t *testing.T) {
// Add a collection to cache first
mgr.leaderMut.Lock()
mgr.collLeader = map[string]map[string]*shardLeaders{
"default": {
"test_collection": {
idx: atomic.NewInt64(0),
collectionID: 100,
shardLeaders: map[string][]NodeInfo{
"channel-1": {node},
},
},
},
}
mgr.leaderMut.Unlock()
mgr.DeprecateShardCache("default", "test_collection")
// Verify cache is cleared
mgr.leaderMut.RLock()
defer mgr.leaderMut.RUnlock()
_, exists := mgr.collLeader["default"]["test_collection"]
assert.False(t, exists)
})
t.Run("Clear one collection, keep others", func(t *testing.T) {
mgr.leaderMut.Lock()
mgr.collLeader = map[string]map[string]*shardLeaders{
"default": {
"collection1": {
idx: atomic.NewInt64(0),
collectionID: 100,
shardLeaders: map[string][]NodeInfo{
"channel-1": {node},
},
},
"collection2": {
idx: atomic.NewInt64(0),
collectionID: 101,
shardLeaders: map[string][]NodeInfo{
"channel-2": {node},
},
},
},
}
mgr.leaderMut.Unlock()
mgr.DeprecateShardCache("default", "collection1")
// Verify collection1 is cleared but collection2 remains
mgr.leaderMut.RLock()
defer mgr.leaderMut.RUnlock()
_, exists := mgr.collLeader["default"]["collection1"]
assert.False(t, exists)
_, exists = mgr.collLeader["default"]["collection2"]
assert.True(t, exists)
})
t.Run("Clear last collection in database removes database", func(t *testing.T) {
mgr.leaderMut.Lock()
mgr.collLeader = map[string]map[string]*shardLeaders{
"test_db": {
"last_collection": {
idx: atomic.NewInt64(0),
collectionID: 200,
shardLeaders: map[string][]NodeInfo{
"channel-1": {node},
},
},
},
}
mgr.leaderMut.Unlock()
mgr.DeprecateShardCache("test_db", "last_collection")
// Verify database is also removed
mgr.leaderMut.RLock()
defer mgr.leaderMut.RUnlock()
_, exists := mgr.collLeader["test_db"]
assert.False(t, exists)
})
mgr.Close()
}
func TestInvalidateShardLeaderCache(t *testing.T) {
node := NodeInfo{
NodeID: 1,
}
qn := mocks.NewMockQueryNodeClient(t)
qn.EXPECT().Close().Return(nil).Maybe()
creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
return qn, nil
}
mixcoord := mocks.NewMockMixCoordClient(t)
mgr := NewShardClientMgr(mixcoord)
mgr.SetClientCreatorFunc(creator)
t.Run("Invalidate single collection", func(t *testing.T) {
mgr.leaderMut.Lock()
mgr.collLeader = map[string]map[string]*shardLeaders{
"default": {
"collection1": {
idx: atomic.NewInt64(0),
collectionID: 100,
shardLeaders: map[string][]NodeInfo{
"channel-1": {node},
},
},
"collection2": {
idx: atomic.NewInt64(0),
collectionID: 101,
shardLeaders: map[string][]NodeInfo{
"channel-2": {node},
},
},
},
}
mgr.leaderMut.Unlock()
mgr.InvalidateShardLeaderCache([]int64{100})
// Verify collection with ID 100 is removed, but 101 remains
mgr.leaderMut.RLock()
defer mgr.leaderMut.RUnlock()
_, exists := mgr.collLeader["default"]["collection1"]
assert.False(t, exists)
_, exists = mgr.collLeader["default"]["collection2"]
assert.True(t, exists)
})
t.Run("Invalidate multiple collections", func(t *testing.T) {
mgr.leaderMut.Lock()
mgr.collLeader = map[string]map[string]*shardLeaders{
"default": {
"collection1": {
idx: atomic.NewInt64(0),
collectionID: 100,
shardLeaders: map[string][]NodeInfo{
"channel-1": {node},
},
},
"collection2": {
idx: atomic.NewInt64(0),
collectionID: 101,
shardLeaders: map[string][]NodeInfo{
"channel-2": {node},
},
},
"collection3": {
idx: atomic.NewInt64(0),
collectionID: 102,
shardLeaders: map[string][]NodeInfo{
"channel-3": {node},
},
},
},
}
mgr.leaderMut.Unlock()
mgr.InvalidateShardLeaderCache([]int64{100, 102})
// Verify collections 100 and 102 are removed, but 101 remains
mgr.leaderMut.RLock()
defer mgr.leaderMut.RUnlock()
_, exists := mgr.collLeader["default"]["collection1"]
assert.False(t, exists)
_, exists = mgr.collLeader["default"]["collection2"]
assert.True(t, exists)
_, exists = mgr.collLeader["default"]["collection3"]
assert.False(t, exists)
})
t.Run("Invalidate non-existent collection", func(t *testing.T) {
mgr.leaderMut.Lock()
mgr.collLeader = map[string]map[string]*shardLeaders{
"default": {
"collection1": {
idx: atomic.NewInt64(0),
collectionID: 100,
shardLeaders: map[string][]NodeInfo{
"channel-1": {node},
},
},
},
}
mgr.leaderMut.Unlock()
mgr.InvalidateShardLeaderCache([]int64{999})
// Verify collection1 still exists
mgr.leaderMut.RLock()
defer mgr.leaderMut.RUnlock()
_, exists := mgr.collLeader["default"]["collection1"]
assert.True(t, exists)
})
t.Run("Invalidate all collections in database removes database", func(t *testing.T) {
mgr.leaderMut.Lock()
mgr.collLeader = map[string]map[string]*shardLeaders{
"test_db": {
"collection1": {
idx: atomic.NewInt64(0),
collectionID: 200,
shardLeaders: map[string][]NodeInfo{
"channel-1": {node},
},
},
"collection2": {
idx: atomic.NewInt64(0),
collectionID: 201,
shardLeaders: map[string][]NodeInfo{
"channel-2": {node},
},
},
},
}
mgr.leaderMut.Unlock()
mgr.InvalidateShardLeaderCache([]int64{200, 201})
// Verify database is removed
mgr.leaderMut.RLock()
defer mgr.leaderMut.RUnlock()
_, exists := mgr.collLeader["test_db"]
assert.False(t, exists)
})
t.Run("Invalidate across multiple databases", func(t *testing.T) {
mgr.leaderMut.Lock()
mgr.collLeader = map[string]map[string]*shardLeaders{
"db1": {
"collection1": {
idx: atomic.NewInt64(0),
collectionID: 100,
shardLeaders: map[string][]NodeInfo{
"channel-1": {node},
},
},
},
"db2": {
"collection2": {
idx: atomic.NewInt64(0),
collectionID: 100, // Same collection ID in different database
shardLeaders: map[string][]NodeInfo{
"channel-2": {node},
},
},
},
}
mgr.leaderMut.Unlock()
mgr.InvalidateShardLeaderCache([]int64{100})
// Verify collection is removed from both databases
mgr.leaderMut.RLock()
defer mgr.leaderMut.RUnlock()
_, exists := mgr.collLeader["db1"]
assert.False(t, exists) // db1 should be removed
_, exists = mgr.collLeader["db2"]
assert.False(t, exists) // db2 should be removed
})
mgr.Close()
}
func TestShuffleShardLeaders(t *testing.T) {
t.Run("Shuffle with multiple nodes", func(t *testing.T) {
shards := map[string][]NodeInfo{
"channel-1": {
{NodeID: 1, Address: "localhost:9000", Serviceable: true},
{NodeID: 2, Address: "localhost:9001", Serviceable: true},
{NodeID: 3, Address: "localhost:9002", Serviceable: true},
},
}
sl := &shardLeaders{
idx: atomic.NewInt64(5),
collectionID: 100,
shardLeaders: shards,
}
reader := sl.GetReader()
result := reader.Shuffle()
// Verify result has same channel
assert.Len(t, result, 1)
assert.Contains(t, result, "channel-1")
// Verify all nodes are present
assert.Len(t, result["channel-1"], 3)
// Verify the first node is based on idx rotation (idx=6, 6%3=0, so nodeID 1 should be first)
assert.Equal(t, int64(1), result["channel-1"][0].NodeID)
// Verify all nodes are still present (shuffled)
nodeIDs := make(map[int64]bool)
for _, node := range result["channel-1"] {
nodeIDs[node.NodeID] = true
}
assert.True(t, nodeIDs[1])
assert.True(t, nodeIDs[2])
assert.True(t, nodeIDs[3])
})
t.Run("Shuffle rotates first replica based on idx", func(t *testing.T) {
shards := map[string][]NodeInfo{
"channel-1": {
{NodeID: 1, Address: "localhost:9000", Serviceable: true},
{NodeID: 2, Address: "localhost:9001", Serviceable: true},
{NodeID: 3, Address: "localhost:9002", Serviceable: true},
},
}
sl := &shardLeaders{
idx: atomic.NewInt64(5),
collectionID: 100,
shardLeaders: shards,
}
// First read, idx will be 6 (5+1), 6%3=0, so first replica should be leaders[0] which is nodeID 1
reader := sl.GetReader()
result := reader.Shuffle()
assert.Equal(t, int64(1), result["channel-1"][0].NodeID)
// Second read, idx will be 7 (6+1), 7%3=1, so first replica should be leaders[1] which is nodeID 2
reader = sl.GetReader()
result = reader.Shuffle()
assert.Equal(t, int64(2), result["channel-1"][0].NodeID)
// Third read, idx will be 8 (7+1), 8%3=2, so first replica should be leaders[2] which is nodeID 3
reader = sl.GetReader()
result = reader.Shuffle()
assert.Equal(t, int64(3), result["channel-1"][0].NodeID)
})
t.Run("Shuffle with single node", func(t *testing.T) {
shards := map[string][]NodeInfo{
"channel-1": {
{NodeID: 1, Address: "localhost:9000", Serviceable: true},
},
}
sl := &shardLeaders{
idx: atomic.NewInt64(0),
collectionID: 100,
shardLeaders: shards,
}
reader := sl.GetReader()
result := reader.Shuffle()
assert.Len(t, result["channel-1"], 1)
assert.Equal(t, int64(1), result["channel-1"][0].NodeID)
})
t.Run("Shuffle with multiple channels", func(t *testing.T) {
shards := map[string][]NodeInfo{
"channel-1": {
{NodeID: 1, Address: "localhost:9000", Serviceable: true},
{NodeID: 2, Address: "localhost:9001", Serviceable: true},
},
"channel-2": {
{NodeID: 3, Address: "localhost:9002", Serviceable: true},
{NodeID: 4, Address: "localhost:9003", Serviceable: true},
},
}
sl := &shardLeaders{
idx: atomic.NewInt64(0),
collectionID: 100,
shardLeaders: shards,
}
reader := sl.GetReader()
result := reader.Shuffle()
// Verify both channels are present
assert.Len(t, result, 2)
assert.Contains(t, result, "channel-1")
assert.Contains(t, result, "channel-2")
// Verify each channel has correct number of nodes
assert.Len(t, result["channel-1"], 2)
assert.Len(t, result["channel-2"], 2)
})
t.Run("Shuffle with empty leaders", func(t *testing.T) {
shards := map[string][]NodeInfo{}
sl := &shardLeaders{
idx: atomic.NewInt64(0),
collectionID: 100,
shardLeaders: shards,
}
reader := sl.GetReader()
result := reader.Shuffle()
assert.Len(t, result, 0)
})
}
// func BenchmarkShardClientMgr(b *testing.B) {
// node := nodeInfo{
// nodeID: 1,
// }
// cache := NewMockCache(b)
// cache.EXPECT().ListShardLocation().Return(map[int64]nodeInfo{
// 1: node,
// }).Maybe()
// globalMetaCache = cache
// qn := mocks.NewMockQueryNodeClient(b)
// qn.EXPECT().Close().Return(nil).Maybe()
// creator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNodeClient, error) {
// return qn, nil
// }
// s := &shardClientMgrImpl{
// clients: typeutil.NewConcurrentMap[UniqueID, *shardClient](),
// clientCreator: creator,
// closeCh: make(chan struct{}),
// purgeInterval: 1 * time.Second,
// expiredDuration: 10 * time.Second,
// }
// go s.PurgeClient()
// defer s.Close()
// b.ResetTimer()
// b.RunParallel(func(pb *testing.PB) {
// for pb.Next() {
// _, err := s.GetClient(context.Background(), node)
// assert.Nil(b, err)
// }
// })
// }
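To make the rotation asserted by the tests above easier to follow, here is a minimal, self-contained sketch of an idx-based shuffle. It is an illustration only, not the shardclient implementation: every type and method body below is assumed from the test expectations (GetReader advances the shared index by one; Shuffle puts leaders[idx % len] first and keeps all replicas).

// Illustrative sketch only: types and logic are assumed from the tests above,
// not copied from the shardclient package.
package main

import (
	"fmt"

	"go.uber.org/atomic"
)

type NodeInfo struct {
	NodeID      int64
	Address     string
	Serviceable bool
}

type shardLeaders struct {
	idx          *atomic.Int64
	shardLeaders map[string][]NodeInfo
}

type shardLeadersReader struct {
	leaders *shardLeaders
	idx     int64
}

// GetReader advances the shared counter and snapshots it, which is why the
// tests observe leaders[0], leaders[1], leaders[2] on three consecutive reads.
func (sl *shardLeaders) GetReader() shardLeadersReader {
	return shardLeadersReader{leaders: sl, idx: sl.idx.Inc()}
}

// Shuffle rotates every channel's replica list so that entry idx%len(leaders)
// comes first; all replicas stay in the result, only their order changes.
func (r shardLeadersReader) Shuffle() map[string][]NodeInfo {
	result := make(map[string][]NodeInfo, len(r.leaders.shardLeaders))
	for channel, leaders := range r.leaders.shardLeaders {
		if len(leaders) == 0 {
			result[channel] = leaders
			continue
		}
		start := int(r.idx % int64(len(leaders)))
		rotated := make([]NodeInfo, 0, len(leaders))
		rotated = append(rotated, leaders[start:]...)
		rotated = append(rotated, leaders[:start]...)
		result[channel] = rotated
	}
	return result
}

func main() {
	sl := &shardLeaders{
		idx: atomic.NewInt64(5),
		shardLeaders: map[string][]NodeInfo{
			"channel-1": {{NodeID: 1}, {NodeID: 2}, {NodeID: 3}},
		},
	}
	// idx becomes 6, 7, 8 -> first replica is node 1, then 2, then 3.
	for i := 0; i < 3; i++ {
		fmt.Println(sl.GetReader().Shuffle()["channel-1"][0].NodeID)
	}
}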

View File

@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
@ -3030,7 +3031,7 @@ type RunAnalyzerTask struct {
collectionID typeutil.UniqueID
fieldID typeutil.UniqueID
dbName string
lb LBPolicy
lb shardclient.LBPolicy
result *milvuspb.RunAnalyzerResponse
}
@ -3122,12 +3123,12 @@ func (t *RunAnalyzerTask) runAnalyzerOnShardleader(ctx context.Context, nodeID i
}
func (t *RunAnalyzerTask) Execute(ctx context.Context) error {
err := t.lb.ExecuteOneChannel(ctx, CollectionWorkLoad{
db: t.dbName,
collectionName: t.GetCollectionName(),
collectionID: t.collectionID,
nq: int64(len(t.GetPlaceholder())),
exec: t.runAnalyzerOnShardleader,
err := t.lb.ExecuteOneChannel(ctx, shardclient.CollectionWorkLoad{
Db: t.dbName,
CollectionName: t.GetCollectionName(),
CollectionID: t.collectionID,
Nq: int64(len(t.GetPlaceholder())),
Exec: t.runAnalyzerOnShardleader,
})
return err
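The hunk above, like the matching changes for deleteRunner, queryTask, searchTask and getStatisticsTask further down, only swaps the unexported workload fields for exported ones now that the struct lives in the shardclient package. For orientation, a sketch of what the exported types presumably look like, inferred purely from these call sites; the authoritative definitions live in internal/proxy/shardclient:

// Inferred shape only; field and type names come from the converted call sites.
package shardclient

import (
	"context"

	"github.com/milvus-io/milvus/internal/types"
)

// ExecuteFunc matches the callback signature used by runAnalyzerOnShardleader,
// the delete query function, queryShard, searchShard and getStatisticsShard.
type ExecuteFunc func(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channel string) error

// CollectionWorkLoad carries one unit of work handed to LBPolicy.Execute /
// ExecuteOneChannel; CollectionID is assumed to be an int64-based ID here.
type CollectionWorkLoad struct {
	Db             string
	CollectionName string
	CollectionID   int64
	Nq             int64
	Exec           ExecuteFunc
}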

View File

@ -17,6 +17,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/exprutil"
"github.com/milvus-io/milvus/internal/util/segcore"
@ -249,7 +250,7 @@ type deleteRunner struct {
// for query
msgID int64
ts uint64
lb LBPolicy
lb shardclient.LBPolicy
count atomic.Int64
// task queue
@ -409,7 +410,7 @@ func (dr *deleteRunner) produce(ctx context.Context, primaryKeys *schemapb.IDs,
// getStreamingQueryAndDelteFunc returns the query function used by LBPolicy
// make sure it is concurrency safe
func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) executeFunc {
func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) shardclient.ExecuteFunc {
return func(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channel string) error {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", dr.collectionID),
@ -556,12 +557,12 @@ func (dr *deleteRunner) complexDelete(ctx context.Context, plan *planpb.PlanNode
return err
}
err = dr.lb.Execute(ctx, CollectionWorkLoad{
db: dr.req.GetDbName(),
collectionName: dr.req.GetCollectionName(),
collectionID: dr.collectionID,
nq: 1,
exec: dr.getStreamingQueryAndDelteFunc(plan),
err = dr.lb.Execute(ctx, shardclient.CollectionWorkLoad{
Db: dr.req.GetDbName(),
CollectionName: dr.req.GetCollectionName(),
CollectionID: dr.collectionID,
Nq: 1,
Exec: dr.getStreamingQueryAndDelteFunc(plan),
})
dr.result.DeleteCnt = dr.count.Load()
dr.result.Timestamp = dr.sessionTS.Load()

View File

@ -17,6 +17,7 @@ import (
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/util/streamrpc"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
@ -683,7 +684,7 @@ func TestDeleteRunner_Run(t *testing.T) {
t.Run("simple delete task failed", func(t *testing.T) {
mockMgr := NewMockChannelsMgr(t)
lb := NewMockLBPolicy(t)
lb := shardclient.NewMockLBPolicy(t)
expr := "pk in [1,2,3]"
plan, err := planparserv2.CreateRetrievePlan(schema.schemaHelper, expr, nil)
@ -722,7 +723,7 @@ func TestDeleteRunner_Run(t *testing.T) {
t.Run("complex delete query rpc failed", func(t *testing.T) {
mockMgr := NewMockChannelsMgr(t)
qn := mocks.NewMockQueryNodeClient(t)
lb := NewMockLBPolicy(t)
lb := shardclient.NewMockLBPolicy(t)
expr := "pk < 3"
plan, err := planparserv2.CreateRetrievePlan(schema.schemaHelper, expr, nil)
require.NoError(t, err)
@ -749,8 +750,8 @@ func TestDeleteRunner_Run(t *testing.T) {
},
plan: plan,
}
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "")
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload shardclient.CollectionWorkLoad) error {
return workload.Exec(ctx, 1, qn, "")
})
qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
@ -764,7 +765,7 @@ func TestDeleteRunner_Run(t *testing.T) {
mockMgr := NewMockChannelsMgr(t)
qn := mocks.NewMockQueryNodeClient(t)
lb := NewMockLBPolicy(t)
lb := shardclient.NewMockLBPolicy(t)
expr := "pk < 3"
plan, err := planparserv2.CreateRetrievePlan(schema.schemaHelper, expr, nil)
require.NoError(t, err)
@ -795,8 +796,8 @@ func TestDeleteRunner_Run(t *testing.T) {
}
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "")
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload shardclient.CollectionWorkLoad) error {
return workload.Exec(ctx, 1, qn, "")
})
qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Call.Return(
@ -830,7 +831,7 @@ func TestDeleteRunner_Run(t *testing.T) {
mockMgr := NewMockChannelsMgr(t)
qn := mocks.NewMockQueryNodeClient(t)
lb := NewMockLBPolicy(t)
lb := shardclient.NewMockLBPolicy(t)
expr := "pk < 3"
plan, err := planparserv2.CreateRetrievePlan(schema.schemaHelper, expr, nil)
require.NoError(t, err)
@ -860,8 +861,8 @@ func TestDeleteRunner_Run(t *testing.T) {
},
plan: plan,
}
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "")
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload shardclient.CollectionWorkLoad) error {
return workload.Exec(ctx, 1, qn, "")
})
qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Call.Return(
@ -893,7 +894,7 @@ func TestDeleteRunner_Run(t *testing.T) {
mockMgr := NewMockChannelsMgr(t)
qn := mocks.NewMockQueryNodeClient(t)
lb := NewMockLBPolicy(t)
lb := shardclient.NewMockLBPolicy(t)
expr := "pk < 3"
plan, err := planparserv2.CreateRetrievePlan(schema.schemaHelper, expr, nil)
require.NoError(t, err)
@ -923,8 +924,8 @@ func TestDeleteRunner_Run(t *testing.T) {
plan: plan,
}
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "")
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload shardclient.CollectionWorkLoad) error {
return workload.Exec(ctx, 1, qn, "")
})
qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Call.Return(
@ -957,7 +958,7 @@ func TestDeleteRunner_Run(t *testing.T) {
mockMgr := NewMockChannelsMgr(t)
qn := mocks.NewMockQueryNodeClient(t)
lb := NewMockLBPolicy(t)
lb := shardclient.NewMockLBPolicy(t)
expr := "pk < 3"
plan, err := planparserv2.CreateRetrievePlan(schema.schemaHelper, expr, nil)
require.NoError(t, err)
@ -987,8 +988,8 @@ func TestDeleteRunner_Run(t *testing.T) {
plan: plan,
}
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "")
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload shardclient.CollectionWorkLoad) error {
return workload.Exec(ctx, 1, qn, "")
})
qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Call.Return(
@ -1025,7 +1026,7 @@ func TestDeleteRunner_Run(t *testing.T) {
mockMgr := NewMockChannelsMgr(t)
qn := mocks.NewMockQueryNodeClient(t)
lb := NewMockLBPolicy(t)
lb := shardclient.NewMockLBPolicy(t)
mockCache := NewMockCache(t)
mockCache.EXPECT().GetCollectionID(mock.Anything, dbName, collectionName).Return(collectionID, nil).Maybe()
@ -1059,8 +1060,8 @@ func TestDeleteRunner_Run(t *testing.T) {
plan: plan,
}
mockMgr.EXPECT().getChannels(collectionID).Return(channels, nil)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload CollectionWorkLoad) error {
return workload.exec(ctx, 1, qn, "")
lb.EXPECT().Execute(mock.Anything, mock.Anything).Call.Return(func(ctx context.Context, workload shardclient.CollectionWorkLoad) error {
return workload.Exec(ctx, 1, qn, "")
})
qn.EXPECT().QueryStream(mock.Anything, mock.Anything).Call.Return(

View File

@ -85,9 +85,8 @@ func TestGetIndexStateTask_Execute(t *testing.T) {
collectionID: collectionID,
}
shardMgr := newShardClientMgr()
// failed to get collection id.
err := InitMetaCache(ctx, queryCoord, shardMgr)
err := InitMetaCache(ctx, queryCoord)
assert.NoError(t, err)
assert.Error(t, gist.Execute(ctx))
}

View File

@ -1,5 +1,6 @@
package proxy
/*
import (
"context"
@ -7,6 +8,7 @@ import (
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
@ -16,7 +18,7 @@ import (
type queryFunc func(context.Context, UniqueID, types.QueryNodeClient, ...string) error
type pickShardPolicy func(context.Context, shardClientMgr, queryFunc, map[string][]nodeInfo) error
type pickShardPolicy func(context.Context, shardclient.ShardClientMgr, queryFunc, map[string][]nodeInfo) error
var errInvalidShardLeaders = errors.New("Invalid shard leader")
@ -24,7 +26,7 @@ var errInvalidShardLeaders = errors.New("Invalid shard leader")
// if request failed, it finds shard leader for failed dml channels
func RoundRobinPolicy(
ctx context.Context,
mgr shardClientMgr,
mgr shardclient.ShardClientMgr,
query queryFunc,
dml2leaders map[string][]nodeInfo,
) error {
@ -65,3 +67,4 @@ func RoundRobinPolicy(
err := wg.Wait()
return err
}
*/

View File

@ -1,5 +1,6 @@
package proxy
/*
import (
"context"
"sort"
@ -83,3 +84,4 @@ func (m *mockQuery) records() map[UniqueID][]string {
}
return m.queryset
}
*/

View File

@ -17,6 +17,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/proxy/accesslog"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/exprutil"
"github.com/milvus-io/milvus/internal/util/reduce"
@ -69,7 +70,8 @@ type queryTask struct {
plan *planpb.PlanNode
partitionKeyMode bool
lb LBPolicy
shardclientMgr shardclient.ShardClientMgr
lb shardclient.LBPolicy
channelsMvcc map[string]Timestamp
fastSkip bool
@ -575,12 +577,12 @@ func (t *queryTask) Execute(ctx context.Context) error {
zap.String("requestType", "query"))
t.resultBuf = typeutil.NewConcurrentSet[*internalpb.RetrieveResults]()
err := t.lb.Execute(ctx, CollectionWorkLoad{
db: t.request.GetDbName(),
collectionID: t.CollectionID,
collectionName: t.collectionName,
nq: 1,
exec: t.queryShard,
err := t.lb.Execute(ctx, shardclient.CollectionWorkLoad{
Db: t.request.GetDbName(),
CollectionID: t.CollectionID,
CollectionName: t.collectionName,
Nq: 1,
Exec: t.queryShard,
})
if err != nil {
log.Warn("fail to execute query", zap.Error(err))
@ -730,13 +732,13 @@ func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.Query
result, err := qn.Query(ctx, req)
if err != nil {
log.Warn("QueryNode query return error", zap.Error(err))
globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
t.shardclientMgr.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
return err
}
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
log.Warn("QueryNode is not shardLeader")
globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
return errInvalidShardLeaders
t.shardclientMgr.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
return merr.Error(result.GetStatus())
}
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("QueryNode query result error", zap.Any("errorCode", result.GetStatus().GetErrorCode()), zap.String("reason", result.GetStatus().GetReason()))

View File

@ -31,6 +31,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
@ -60,11 +61,16 @@ func TestQueryTask_all(t *testing.T) {
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything, mock.Anything).Return(nil, nil).Maybe()
mgr := NewMockShardClientManager(t)
mgr := shardclient.NewMockShardClientManager(t)
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
lb := NewLBPolicyImpl(mgr)
mgr.EXPECT().GetShardLeaderList(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]string{"mock_qn"}, nil).Maybe()
mgr.EXPECT().GetShard(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]shardclient.NodeInfo{
{NodeID: 1, Address: "mock_qn", Serviceable: true},
}, nil).Maybe()
mgr.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
lb := shardclient.NewLBPolicyImpl(mgr)
err = InitMetaCache(ctx, qc, mgr)
err = InitMetaCache(ctx, qc)
assert.NoError(t, err)
fieldName2Types := map[string]schemapb.DataType{
@ -141,8 +147,9 @@ func TestQueryTask_all(t *testing.T) {
},
},
},
mixCoord: qc,
lb: lb,
mixCoord: qc,
lb: lb,
shardclientMgr: mgr,
}
assert.NoError(t, task.OnEnqueue())
@ -290,9 +297,10 @@ func TestQueryTask_all(t *testing.T) {
},
},
},
mixCoord: qc,
lb: lb,
resultBuf: &typeutil.ConcurrentSet[*internalpb.RetrieveResults]{},
mixCoord: qc,
lb: lb,
shardclientMgr: mgr,
resultBuf: &typeutil.ConcurrentSet[*internalpb.RetrieveResults]{},
}
// simulate scheduler enqueue task
enqueTs := uint64(10000)
@ -341,9 +349,10 @@ func TestQueryTask_all(t *testing.T) {
},
GuaranteeTimestamp: enqueTs,
},
mixCoord: qc,
lb: lb,
resultBuf: &typeutil.ConcurrentSet[*internalpb.RetrieveResults]{},
mixCoord: qc,
lb: lb,
shardclientMgr: mgr,
resultBuf: &typeutil.ConcurrentSet[*internalpb.RetrieveResults]{},
}
qtErr = qt.PreExecute(context.TODO())
assert.Nil(t, qtErr)

View File

@ -19,6 +19,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/proxy/accesslog"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/exprutil"
"github.com/milvus-io/milvus/internal/util/function/embedding"
@ -84,7 +85,8 @@ type searchTask struct {
mixCoord types.MixCoordClient
node types.ProxyComponent
lb LBPolicy
lb shardclient.LBPolicy
shardClientMgr shardclient.ShardClientMgr
queryChannelsTs map[string]Timestamp
queryInfos []*planpb.QueryInfo
relatedDataSize int64
@ -718,12 +720,12 @@ func (t *searchTask) Execute(ctx context.Context) error {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute search %d", t.ID()))
defer tr.CtxElapse(ctx, "done")
err := t.lb.Execute(ctx, CollectionWorkLoad{
db: t.request.GetDbName(),
collectionID: t.SearchRequest.CollectionID,
collectionName: t.collectionName,
nq: t.Nq,
exec: t.searchShard,
err := t.lb.Execute(ctx, shardclient.CollectionWorkLoad{
Db: t.request.GetDbName(),
CollectionID: t.SearchRequest.CollectionID,
CollectionName: t.collectionName,
Nq: t.Nq,
Exec: t.searchShard,
})
if err != nil {
log.Warn("search execute failed", zap.Error(err))
@ -928,13 +930,14 @@ func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.Que
result, err = qn.Search(ctx, req)
if err != nil {
log.Warn("QueryNode search return error", zap.Error(err))
globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
// globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
t.shardClientMgr.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
return err
}
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
log.Warn("QueryNode is not shardLeader")
globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
return errInvalidShardLeaders
t.shardClientMgr.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
return merr.Error(result.GetStatus())
}
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Warn("QueryNode search result error",

View File

@ -39,6 +39,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/function/embedding"
@ -67,9 +68,8 @@ func TestSearchTask_PostExecute(t *testing.T) {
)
require.NoError(t, err)
mgr := newShardClientMgr()
err = InitMetaCache(ctx, qc, mgr)
err = InitMetaCache(ctx, qc)
require.NoError(t, err)
getSearchTask := func(t *testing.T, collName string) *searchTask {
@ -638,8 +638,7 @@ func TestSearchTask_PreExecute(t *testing.T) {
ctx = context.TODO()
)
require.NoError(t, err)
mgr := newShardClientMgr()
err = InitMetaCache(ctx, qc, mgr)
err = InitMetaCache(ctx, qc)
require.NoError(t, err)
getSearchTask := func(t *testing.T, collName string) *searchTask {
@ -1039,8 +1038,7 @@ func TestSearchTask_WithFunctions(t *testing.T) {
)
require.NoError(t, err)
mgr := newShardClientMgr()
err = InitMetaCache(ctx, qc, mgr)
err = InitMetaCache(ctx, qc)
require.NoError(t, err)
getSearchTask := func(t *testing.T, collName string, data []string, withRerank bool) *searchTask {
@ -1105,8 +1103,6 @@ func TestSearchTask_WithFunctions(t *testing.T) {
cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(info, nil).Maybe()
cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe()
cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Maybe()
cache.EXPECT().GetShard(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]nodeInfo{}, nil).Maybe()
cache.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
globalMetaCache = cache
{
@ -1266,8 +1262,7 @@ func TestSearchTaskV2_Execute(t *testing.T) {
collectionName = t.Name() + funcutil.GenRandomStr()
)
mgr := newShardClientMgr()
err = InitMetaCache(ctx, qc, mgr)
err = InitMetaCache(ctx, qc)
require.NoError(t, err)
defer qc.Close()
@ -2925,13 +2920,18 @@ func TestSearchTask_ErrExecute(t *testing.T) {
collectionName = t.Name() + funcutil.GenRandomStr()
)
mgr := NewMockShardClientManager(t)
mgr := shardclient.NewMockShardClientManager(t)
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
lb := NewLBPolicyImpl(mgr)
mgr.EXPECT().GetShardLeaderList(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]string{"mock_qn"}, nil).Maybe()
mgr.EXPECT().GetShard(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]shardclient.NodeInfo{
{NodeID: 1, Address: "mock_qn", Serviceable: true},
}, nil).Maybe()
mgr.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
lb := shardclient.NewLBPolicyImpl(mgr)
defer rc.Close()
err = InitMetaCache(ctx, rc, mgr)
err = InitMetaCache(ctx, rc)
assert.NoError(t, err)
fieldName2Types := map[string]schemapb.DataType{
@ -3027,8 +3027,9 @@ func TestSearchTask_ErrExecute(t *testing.T) {
Nq: 2,
DslType: commonpb.DslType_BoolExprV1,
},
mixCoord: rc,
lb: lb,
mixCoord: rc,
lb: lb,
shardClientMgr: mgr,
}
for i := 0; i < len(fieldName2Types); i++ {
task.SearchRequest.OutputFieldsId[i] = int64(common.StartOfUserFieldID + i)
@ -4066,11 +4067,18 @@ func TestSearchTask_Requery(t *testing.T) {
cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(schema, nil).Maybe()
cache.EXPECT().GetPartitions(mock.Anything, mock.Anything, mock.Anything).Return(map[string]int64{"_default": UniqueID(1)}, nil).Maybe()
cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{}, nil).Maybe()
cache.EXPECT().GetShard(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]nodeInfo{}, nil).Maybe()
cache.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
cache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{}, nil).Maybe()
globalMetaCache = cache
mgr := shardclient.NewMockShardClientManager(t)
// mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
mgr.EXPECT().GetShardLeaderList(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]string{"mock_qn"}, nil).Maybe()
mgr.EXPECT().GetShard(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]shardclient.NodeInfo{
{NodeID: 1, Address: "mock_qn", Serviceable: true},
}, nil).Maybe()
mgr.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
node.shardMgr = mgr
t.Run("Test normal", func(t *testing.T) {
collSchema := constructCollectionSchema(pkField, vecField, dim, collection)
schema := newSchemaInfo(collSchema)
@ -4115,9 +4123,9 @@ func TestSearchTask_Requery(t *testing.T) {
}, nil
})
lb := NewMockLBPolicy(t)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) {
err = workload.exec(ctx, 0, qn, "")
lb := shardclient.NewMockLBPolicy(t)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload shardclient.CollectionWorkLoad) {
err = workload.Exec(ctx, 0, qn, "")
assert.NoError(t, err)
}).Return(nil)
lb.EXPECT().UpdateCostMetrics(mock.Anything, mock.Anything).Return()
@ -4153,6 +4161,7 @@ func TestSearchTask_Requery(t *testing.T) {
tr: timerecord.NewTimeRecorder("search"),
node: node,
translatedOutputFields: outputFields,
shardClientMgr: mgr,
}
op, err := newRequeryOperator(qt, nil)
assert.NoError(t, err)
@ -4199,9 +4208,9 @@ func TestSearchTask_Requery(t *testing.T) {
qn.EXPECT().Query(mock.Anything, mock.Anything).
Return(nil, errors.New("mock err 1"))
lb := NewMockLBPolicy(t)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) {
_ = workload.exec(ctx, 0, qn, "")
lb := shardclient.NewMockLBPolicy(t)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload shardclient.CollectionWorkLoad) {
_ = workload.Exec(ctx, 0, qn, "")
}).Return(errors.New("mock err 1"))
node.lbPolicy = lb
@ -4216,9 +4225,10 @@ func TestSearchTask_Requery(t *testing.T) {
request: &milvuspb.SearchRequest{
CollectionName: collectionName,
},
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
schema: schema,
tr: timerecord.NewTimeRecorder("search"),
node: node,
shardClientMgr: mgr,
}
op, err := newRequeryOperator(qt, nil)
@ -4235,9 +4245,9 @@ func TestSearchTask_Requery(t *testing.T) {
qn.EXPECT().Query(mock.Anything, mock.Anything).
Return(nil, errors.New("mock err 1"))
lb := NewMockLBPolicy(t)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload CollectionWorkLoad) {
_ = workload.exec(ctx, 0, qn, "")
lb := shardclient.NewMockLBPolicy(t)
lb.EXPECT().Execute(mock.Anything, mock.Anything).Run(func(ctx context.Context, workload shardclient.CollectionWorkLoad) {
_ = workload.Exec(ctx, 0, qn, "")
}).Return(errors.New("mock err 1"))
node.lbPolicy = lb
@ -4265,11 +4275,12 @@ func TestSearchTask_Requery(t *testing.T) {
Ids: resultIDs,
},
},
needRequery: true,
schema: schema,
resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](),
tr: timerecord.NewTimeRecorder("search"),
node: node,
needRequery: true,
schema: schema,
resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](),
tr: timerecord.NewTimeRecorder("search"),
node: node,
shardClientMgr: mgr,
}
scores := make([]float32, rows)
for i := range scores {

View File

@ -12,6 +12,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
@ -54,7 +55,8 @@ type getStatisticsTask struct {
*internalpb.GetStatisticsRequest
resultBuf *typeutil.ConcurrentSet[*internalpb.GetStatisticsResponse]
lb LBPolicy
shardclientMgr shardclient.ShardClientMgr
lb shardclient.LBPolicy
}
func (g *getStatisticsTask) TraceCtx() context.Context {
@ -258,12 +260,12 @@ func (g *getStatisticsTask) getStatisticsFromQueryNode(ctx context.Context) erro
if g.resultBuf == nil {
g.resultBuf = typeutil.NewConcurrentSet[*internalpb.GetStatisticsResponse]()
}
err := g.lb.Execute(ctx, CollectionWorkLoad{
db: g.request.GetDbName(),
collectionID: g.GetStatisticsRequest.CollectionID,
collectionName: g.collectionName,
nq: 1,
exec: g.getStatisticsShard,
err := g.lb.Execute(ctx, shardclient.CollectionWorkLoad{
Db: g.request.GetDbName(),
CollectionID: g.GetStatisticsRequest.CollectionID,
CollectionName: g.collectionName,
Nq: 1,
Exec: g.getStatisticsShard,
})
if err != nil {
return errors.Wrap(err, "failed to statistic")
@ -286,15 +288,15 @@ func (g *getStatisticsTask) getStatisticsShard(ctx context.Context, nodeID int64
zap.Int64("nodeID", nodeID),
zap.String("channel", channel),
zap.Error(err))
globalMetaCache.DeprecateShardCache(g.request.GetDbName(), g.collectionName)
g.shardclientMgr.DeprecateShardCache(g.request.GetDbName(), g.collectionName)
return err
}
if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
log.Ctx(ctx).Warn("QueryNode is not shardLeader",
zap.Int64("nodeID", nodeID),
zap.String("channel", channel))
globalMetaCache.DeprecateShardCache(g.request.GetDbName(), g.collectionName)
return errInvalidShardLeaders
g.shardclientMgr.DeprecateShardCache(g.request.GetDbName(), g.collectionName)
return merr.Error(result.GetStatus())
}
if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
log.Ctx(ctx).Warn("QueryNode statistic result error",

View File

@ -29,6 +29,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
@ -43,7 +44,8 @@ type StatisticTaskSuite struct {
mixc types.MixCoordClient
qn *mocks.MockQueryNodeClient
lb LBPolicy
lb shardclient.LBPolicy
mgr shardclient.ShardClientMgr
collectionName string
collectionID int64
@ -82,11 +84,17 @@ func (s *StatisticTaskSuite) SetupTest() {
s.qn = mocks.NewMockQueryNodeClient(s.T())
s.qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
mgr := NewMockShardClientManager(s.T())
mgr := shardclient.NewMockShardClientManager(s.T())
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(s.qn, nil).Maybe()
s.lb = NewLBPolicyImpl(mgr)
mgr.EXPECT().GetShardLeaderList(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]string{"mock_qn"}, nil).Maybe()
mgr.EXPECT().GetShard(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]shardclient.NodeInfo{
{NodeID: 1, Address: "mock_qn", Serviceable: true},
}, nil).Maybe()
mgr.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
s.mgr = mgr
s.lb = shardclient.NewLBPolicyImpl(mgr)
err := InitMetaCache(context.Background(), s.mixc, mgr)
err := InitMetaCache(context.Background(), s.mixc)
s.NoError(err)
s.collectionName = "test_statistics_task"
@ -178,8 +186,9 @@ func (s *StatisticTaskSuite) getStatisticsTask(ctx context.Context) *getStatisti
},
CollectionName: s.collectionName,
},
mixc: s.mixc,
lb: s.lb,
mixc: s.mixc,
lb: s.lb,
shardclientMgr: s.mgr,
}
}

View File

@ -473,8 +473,7 @@ func TestAlterCollection_AllowInsertAutoID_Validation(t *testing.T) {
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
root := buildRoot(true)
mgr := newShardClientMgr()
err := InitMetaCache(ctx, root, mgr)
err := InitMetaCache(ctx, root)
assert.NoError(t, err)
task := &alterCollectionTask{
@ -495,8 +494,7 @@ func TestAlterCollection_AllowInsertAutoID_Validation(t *testing.T) {
cache := globalMetaCache
defer func() { globalMetaCache = cache }()
root := buildRoot(false)
mgr := newShardClientMgr()
err := InitMetaCache(ctx, root, mgr)
err := InitMetaCache(ctx, root)
assert.NoError(t, err)
task := &alterCollectionTask{
@ -1602,8 +1600,7 @@ func TestHasCollectionTask(t *testing.T) {
mixc := NewMixCoordMock()
defer mixc.Close()
ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, mixc, mgr)
InitMetaCache(ctx, mixc)
prefix := "TestHasCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
@ -1699,8 +1696,7 @@ func TestDescribeCollectionTask(t *testing.T) {
mixc := NewMixCoordMock()
defer mixc.Close()
ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, mixc, mgr)
InitMetaCache(ctx, mixc)
prefix := "TestDescribeCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
@ -1757,8 +1753,7 @@ func TestDescribeCollectionTask_ShardsNum1(t *testing.T) {
mix := NewMixCoordMock()
ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, mix, mgr)
InitMetaCache(ctx, mix)
prefix := "TestDescribeCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
@ -1816,8 +1811,7 @@ func TestDescribeCollectionTask_ShardsNum1(t *testing.T) {
func TestDescribeCollectionTask_EnableDynamicSchema(t *testing.T) {
mix := NewMixCoordMock()
ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, mix, mgr)
InitMetaCache(ctx, mix)
prefix := "TestDescribeCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
@ -1876,8 +1870,7 @@ func TestDescribeCollectionTask_EnableDynamicSchema(t *testing.T) {
func TestDescribeCollectionTask_ShardsNum2(t *testing.T) {
mix := NewMixCoordMock()
ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, mix, mgr)
InitMetaCache(ctx, mix)
prefix := "TestDescribeCollectionTask"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
@ -2238,8 +2231,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
qc := NewMixCoordMock()
ctx := context.Background()
mgr := newShardClientMgr()
err = InitMetaCache(ctx, qc, mgr)
err = InitMetaCache(ctx, qc)
assert.NoError(t, err)
shardsNum := int32(2)
@ -2467,8 +2459,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
ctx := context.Background()
mgr := newShardClientMgr()
err = InitMetaCache(ctx, mixc, mgr)
err = InitMetaCache(ctx, mixc)
assert.NoError(t, err)
shardsNum := int32(2)
@ -3086,9 +3077,8 @@ func Test_loadCollectionTask_Execute(t *testing.T) {
ctx := context.Background()
indexID := int64(1000)
shardMgr := newShardClientMgr()
// failed to get collection id.
_ = InitMetaCache(ctx, rc, shardMgr)
_ = InitMetaCache(ctx, rc)
rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) {
return &milvuspb.DescribeCollectionResponse{
@ -3212,9 +3202,8 @@ func Test_loadPartitionTask_Execute(t *testing.T) {
ctx := context.Background()
indexID := int64(1000)
shardMgr := newShardClientMgr()
// failed to get collection id.
_ = InitMetaCache(ctx, qc, shardMgr)
_ = InitMetaCache(ctx, qc)
qc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) {
return &milvuspb.DescribeCollectionResponse{
@ -3297,8 +3286,7 @@ func TestCreateResourceGroupTask(t *testing.T) {
defer mixc.Close()
ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, mixc, mgr)
InitMetaCache(ctx, mixc)
createRGReq := &milvuspb.CreateResourceGroupRequest{
Base: &commonpb.MsgBase{
@ -3335,8 +3323,7 @@ func TestDropResourceGroupTask(t *testing.T) {
defer mixc.Close()
ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, mixc, mgr)
InitMetaCache(ctx, mixc)
dropRGReq := &milvuspb.DropResourceGroupRequest{
Base: &commonpb.MsgBase{
@ -3372,8 +3359,7 @@ func TestTransferNodeTask(t *testing.T) {
defer mixc.Close()
ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, mixc, mgr)
InitMetaCache(ctx, mixc)
req := &milvuspb.TransferNodeRequest{
Base: &commonpb.MsgBase{
@ -3410,8 +3396,7 @@ func TestTransferReplicaTask(t *testing.T) {
rc := &MockMixCoordClientInterface{}
ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, rc, mgr)
InitMetaCache(ctx, rc)
// make it avoid remote call on rc
globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), "collection1")
@ -3458,8 +3443,7 @@ func TestListResourceGroupsTask(t *testing.T) {
}
ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, rc, mgr)
InitMetaCache(ctx, rc)
req := &milvuspb.ListResourceGroupsRequest{
Base: &commonpb.MsgBase{
@ -3508,8 +3492,7 @@ func TestDescribeResourceGroupTask(t *testing.T) {
}
ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, rc, mgr)
InitMetaCache(ctx, rc)
// make it avoid remote call on rc
globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), "collection1")
globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), "collection2")
@ -3558,8 +3541,7 @@ func TestDescribeResourceGroupTaskFailed(t *testing.T) {
}
ctx := context.Background()
mgr := newShardClientMgr()
InitMetaCache(ctx, rc, mgr)
InitMetaCache(ctx, rc)
// make it avoid remote call on rc
globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), "collection1")
globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), "collection2")
@ -3807,7 +3789,7 @@ func TestCreateCollectionTaskWithPartitionKey(t *testing.T) {
assert.NoError(t, err)
// check default partitions
err = InitMetaCache(ctx, rc, nil)
err = InitMetaCache(ctx, rc)
assert.NoError(t, err)
partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, "", task.CollectionName)
assert.NoError(t, err)
@ -3886,8 +3868,7 @@ func TestPartitionKey(t *testing.T) {
defer qc.Close()
ctx := context.Background()
mgr := newShardClientMgr()
err := InitMetaCache(ctx, qc, mgr)
err := InitMetaCache(ctx, qc)
assert.NoError(t, err)
shardsNum := common.DefaultShardsNum
@ -4121,8 +4102,7 @@ func TestDefaultPartition(t *testing.T) {
qc := NewMixCoordMock()
ctx := context.Background()
mgr := newShardClientMgr()
err := InitMetaCache(ctx, qc, mgr)
err := InitMetaCache(ctx, qc)
assert.NoError(t, err)
shardsNum := common.DefaultShardsNum
@ -4302,8 +4282,7 @@ func TestClusteringKey(t *testing.T) {
ctx := context.Background()
mgr := newShardClientMgr()
err := InitMetaCache(ctx, qc, mgr)
err := InitMetaCache(ctx, qc)
assert.NoError(t, err)
shardsNum := common.DefaultShardsNum
@ -4439,7 +4418,7 @@ func TestClusteringKey(t *testing.T) {
func TestAlterCollectionCheckLoaded(t *testing.T) {
qc := NewMixCoordMock()
InitMetaCache(context.Background(), qc, nil)
InitMetaCache(context.Background(), qc)
collectionName := "test_alter_collection_check_loaded"
createColReq := &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{
@ -4481,8 +4460,7 @@ func TestAlterCollectionCheckLoaded(t *testing.T) {
func TestTaskPartitionKeyIsolation(t *testing.T) {
qc := NewMixCoordMock()
ctx := context.Background()
mgr := newShardClientMgr()
err := InitMetaCache(ctx, qc, mgr)
err := InitMetaCache(ctx, qc)
assert.NoError(t, err)
shardsNum := common.DefaultShardsNum
prefix := "TestPartitionKeyIsolation"
@ -4828,7 +4806,7 @@ func TestInsertForReplicate(t *testing.T) {
func TestAlterCollectionFieldCheckLoaded(t *testing.T) {
qc := NewMixCoordMock()
InitMetaCache(context.Background(), qc, nil)
InitMetaCache(context.Background(), qc)
collectionName := "test_alter_collection_field_check_loaded"
createColReq := &milvuspb.CreateCollectionRequest{
Base: &commonpb.MsgBase{
@ -4880,7 +4858,7 @@ func TestAlterCollectionFieldCheckLoaded(t *testing.T) {
func TestAlterCollectionField(t *testing.T) {
qc := NewMixCoordMock()
InitMetaCache(context.Background(), qc, nil)
InitMetaCache(context.Background(), qc)
collectionName := "test_alter_collection_field"
// Create collection with string and array fields
@ -5483,7 +5461,7 @@ func TestDescribeCollectionTaskWithStructArrayField(t *testing.T) {
func TestAlterCollection_AllowInsertAutoID_AutoIDFalse(t *testing.T) {
qc := NewMixCoordMock()
InitMetaCache(context.Background(), qc, nil)
InitMetaCache(context.Background(), qc)
ctx := context.Background()
collectionName := "test_alter_allow_insert_autoid_autoid_false"

View File

@ -32,6 +32,7 @@ import (
grpcmixcoordclient "github.com/milvus-io/milvus/internal/distributed/mixcoord/client"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/util/function/embedding"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
@ -600,7 +601,7 @@ func createTestUpdateTask() *upsertTask {
collectionID: 1001,
node: &Proxy{
mixCoord: mcClient,
lbPolicy: NewLBPolicyImpl(nil),
lbPolicy: shardclient.NewLBPolicyImpl(nil),
},
}