diff --git a/Makefile b/Makefile index 7c599a2de1..04b69123f7 100644 --- a/Makefile +++ b/Makefile @@ -456,7 +456,10 @@ generate-mockery-utils: getdeps $(INSTALL_PATH)/mockery --name=Allocator --dir=internal/tso --output=internal/tso/mocks --filename=allocator.go --with-expecter --structname=Allocator --outpkg=mocktso $(INSTALL_PATH)/mockery --name=SessionInterface --dir=$(PWD)/internal/util/sessionutil --output=$(PWD)/internal/util/sessionutil --filename=mock_session.go --with-expecter --structname=MockSession --inpackage $(INSTALL_PATH)/mockery --name=GrpcClient --dir=$(PWD)/internal/util/grpcclient --output=$(PWD)/internal/mocks --filename=mock_grpc_client.go --with-expecter --structname=MockGrpcClient - + # proxy_client_manager.go + $(INSTALL_PATH)/mockery --name=ProxyClientManagerInterface --dir=$(PWD)/internal/util/proxyutil --output=$(PWD)/internal/util/proxyutil --filename=mock_proxy_client_manager.go --with-expecter --structname=MockProxyClientManager --inpackage + $(INSTALL_PATH)/mockery --name=ProxyWatcherInterface --dir=$(PWD)/internal/util/proxyutil --output=$(PWD)/internal/util/proxyutil --filename=mock_proxy_watcher.go --with-expecter --structname=MockProxyWatcher --inpackage + generate-mockery-kv: getdeps $(INSTALL_PATH)/mockery --name=TxnKV --dir=$(PWD)/internal/kv --output=$(PWD)/internal/kv/mocks --filename=txn_kv.go --with-expecter $(INSTALL_PATH)/mockery --name=MetaKv --dir=$(PWD)/internal/kv --output=$(PWD)/internal/kv/mocks --filename=meta_kv.go --with-expecter diff --git a/internal/rootcoord/drop_collection_task.go b/internal/rootcoord/drop_collection_task.go index f35fca1770..457aa47a47 100644 --- a/internal/rootcoord/drop_collection_task.go +++ b/internal/rootcoord/drop_collection_task.go @@ -26,6 +26,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" pb "github.com/milvus-io/milvus/internal/proto/etcdpb" + "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/typeutil" @@ -79,7 +80,7 @@ func (t *dropCollectionTask) Execute(ctx context.Context) error { collectionNames: append(aliases, collMeta.Name), collectionID: collMeta.CollectionID, ts: ts, - opts: []expireCacheOpt{expireCacheWithDropFlag()}, + opts: []proxyutil.ExpireCacheOpt{proxyutil.ExpireCacheWithDropFlag()}, }) redoTask.AddSyncStep(&changeCollectionStateStep{ baseStep: baseStep{core: t.core}, diff --git a/internal/rootcoord/expire_cache.go b/internal/rootcoord/expire_cache.go index df21a36fe8..67934f30e7 100644 --- a/internal/rootcoord/expire_cache.go +++ b/internal/rootcoord/expire_cache.go @@ -19,40 +19,14 @@ package rootcoord import ( "context" - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/proxypb" + "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/typeutil" ) -type expireCacheConfig struct { - withDropFlag bool -} - -func (c expireCacheConfig) apply(req *proxypb.InvalidateCollMetaCacheRequest) { - if !c.withDropFlag { - return - } - if req.GetBase() == nil { - req.Base = commonpbutil.NewMsgBase() - } - req.Base.MsgType = commonpb.MsgType_DropCollection -} - -func defaultExpireCacheConfig() expireCacheConfig { - return expireCacheConfig{withDropFlag: false} -} - -type expireCacheOpt func(c *expireCacheConfig) - -func expireCacheWithDropFlag() expireCacheOpt { - return func(c *expireCacheConfig) { - c.withDropFlag = true - } -} - // ExpireMetaCache will call invalidate collection meta cache -func (c *Core) ExpireMetaCache(ctx context.Context, dbName string, collNames []string, collectionID UniqueID, ts typeutil.Timestamp, opts ...expireCacheOpt) error { +func (c *Core) ExpireMetaCache(ctx context.Context, dbName string, collNames []string, collectionID UniqueID, ts typeutil.Timestamp, opts ...proxyutil.ExpireCacheOpt) error { // if collectionID is specified, invalidate all the collection meta cache with the specified collectionID and return if collectionID != InvalidCollectionID { req := proxypb.InvalidateCollMetaCacheRequest{ diff --git a/internal/rootcoord/expire_cache_test.go b/internal/rootcoord/expire_cache_test.go index 82782c6753..8245444d1a 100644 --- a/internal/rootcoord/expire_cache_test.go +++ b/internal/rootcoord/expire_cache_test.go @@ -23,15 +23,16 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/proxypb" + "github.com/milvus-io/milvus/internal/util/proxyutil" ) func Test_expireCacheConfig_apply(t *testing.T) { - c := defaultExpireCacheConfig() + c := proxyutil.DefaultExpireCacheConfig() req := &proxypb.InvalidateCollMetaCacheRequest{} - c.apply(req) + c.Apply(req) assert.Nil(t, req.GetBase()) - opt := expireCacheWithDropFlag() + opt := proxyutil.ExpireCacheWithDropFlag() opt(&c) - c.apply(req) + c.Apply(req) assert.Equal(t, commonpb.MsgType_DropCollection, req.GetBase().GetMsgType()) } diff --git a/internal/rootcoord/mock_test.go b/internal/rootcoord/mock_test.go index 6e06bf2ec5..3e603b2f88 100644 --- a/internal/rootcoord/mock_test.go +++ b/internal/rootcoord/mock_test.go @@ -40,6 +40,7 @@ import ( "github.com/milvus-io/milvus/internal/tso" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" @@ -385,9 +386,7 @@ func newTestCore(opts ...Opt) *Core { func withValidProxyManager() Opt { return func(c *Core) { - c.proxyClientManager = &proxyClientManager{ - proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient](), - } + c.proxyClientManager = proxyutil.NewProxyClientManager(proxyutil.DefaultProxyCreator) p := newMockProxy() p.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { return merr.Success(), nil @@ -398,15 +397,14 @@ func withValidProxyManager() Opt { Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, }, nil } - c.proxyClientManager.proxyClient.Insert(TestProxyID, p) + clients := c.proxyClientManager.GetProxyClients() + clients.Insert(TestProxyID, p) } } func withInvalidProxyManager() Opt { return func(c *Core) { - c.proxyClientManager = &proxyClientManager{ - proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient](), - } + c.proxyClientManager = proxyutil.NewProxyClientManager(proxyutil.DefaultProxyCreator) p := newMockProxy() p.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { return merr.Success(), errors.New("error mock InvalidateCollectionMetaCache") @@ -417,7 +415,8 @@ func withInvalidProxyManager() Opt { Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, }, nil } - c.proxyClientManager.proxyClient.Insert(TestProxyID, p) + clients := c.proxyClientManager.GetProxyClients() + clients.Insert(TestProxyID, p) } } diff --git a/internal/rootcoord/proxy_client_manager_test.go b/internal/rootcoord/proxy_client_manager_test.go deleted file mode 100644 index dc3a6dbe17..0000000000 --- a/internal/rootcoord/proxy_client_manager_test.go +++ /dev/null @@ -1,314 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package rootcoord - -import ( - "context" - "fmt" - "sync" - "testing" - - "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus/internal/proto/proxypb" - "github.com/milvus-io/milvus/internal/types" - "github.com/milvus-io/milvus/internal/util/sessionutil" - "github.com/milvus-io/milvus/pkg/util/etcd" - "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus/pkg/util/paramtable" - "github.com/milvus-io/milvus/pkg/util/typeutil" -) - -type proxyMock struct { - types.ProxyClient - collArray []string - collIDs []UniqueID - mutex sync.Mutex - - returnError bool - returnGrpcError bool -} - -func (p *proxyMock) Stop() error { - return nil -} - -func (p *proxyMock) InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - p.mutex.Lock() - defer p.mutex.Unlock() - if p.returnError { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, nil - } - if p.returnGrpcError { - return nil, fmt.Errorf("grpc error") - } - p.collArray = append(p.collArray, request.CollectionName) - p.collIDs = append(p.collIDs, request.CollectionID) - return merr.Success(), nil -} - -func (p *proxyMock) GetCollArray() []string { - p.mutex.Lock() - defer p.mutex.Unlock() - ret := make([]string, 0, len(p.collArray)) - ret = append(ret, p.collArray...) - return ret -} - -func (p *proxyMock) GetCollIDs() []UniqueID { - p.mutex.Lock() - defer p.mutex.Unlock() - ret := p.collIDs - return ret -} - -func (p *proxyMock) InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) { - if p.returnError { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - }, nil - } - if p.returnGrpcError { - return nil, fmt.Errorf("grpc error") - } - return merr.Success(), nil -} - -func (p *proxyMock) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) { - return merr.Success(), nil -} - -func TestProxyClientManager_AddProxyClients(t *testing.T) { - paramtable.Init() - - core, err := NewCore(context.Background(), nil) - assert.NoError(t, err) - cli, err := etcd.GetEtcdClient( - Params.EtcdCfg.UseEmbedEtcd.GetAsBool(), - Params.EtcdCfg.EtcdUseSSL.GetAsBool(), - Params.EtcdCfg.Endpoints.GetAsStrings(), - Params.EtcdCfg.EtcdTLSCert.GetValue(), - Params.EtcdCfg.EtcdTLSKey.GetValue(), - Params.EtcdCfg.EtcdTLSCACert.GetValue(), - Params.EtcdCfg.EtcdTLSMinVersion.GetValue()) - defer cli.Close() - assert.NoError(t, err) - core.etcdCli = cli - core.proxyCreator = func(ctx context.Context, addr string, nodeID int64) (types.ProxyClient, error) { - return nil, errors.New("failed") - } - - pcm := newProxyClientManager(core.proxyCreator) - - session := &sessionutil.Session{ - SessionRaw: sessionutil.SessionRaw{ - ServerID: 100, - Address: "localhost", - }, - } - - sessions := []*sessionutil.Session{session} - pcm.AddProxyClients(sessions) -} - -func TestProxyClientManager_AddProxyClient(t *testing.T) { - paramtable.Init() - - core, err := NewCore(context.Background(), nil) - assert.NoError(t, err) - cli, err := etcd.GetEtcdClient( - Params.EtcdCfg.UseEmbedEtcd.GetAsBool(), - Params.EtcdCfg.EtcdUseSSL.GetAsBool(), - Params.EtcdCfg.Endpoints.GetAsStrings(), - Params.EtcdCfg.EtcdTLSCert.GetValue(), - Params.EtcdCfg.EtcdTLSKey.GetValue(), - Params.EtcdCfg.EtcdTLSCACert.GetValue(), - Params.EtcdCfg.EtcdTLSMinVersion.GetValue()) - assert.NoError(t, err) - defer cli.Close() - core.etcdCli = cli - - core.proxyCreator = func(ctx context.Context, addr string, nodeID int64) (types.ProxyClient, error) { - return nil, errors.New("failed") - } - - pcm := newProxyClientManager(core.proxyCreator) - - session := &sessionutil.Session{ - SessionRaw: sessionutil.SessionRaw{ - ServerID: 100, - Address: "localhost", - }, - } - - pcm.AddProxyClient(session) -} - -func TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) { - t.Run("empty proxy list", func(t *testing.T) { - ctx := context.Background() - pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} - err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) - assert.NoError(t, err) - }) - - t.Run("mock rpc error", func(t *testing.T) { - ctx := context.Background() - p1 := newMockProxy() - p1.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - return merr.Success(), errors.New("error mock InvalidateCollectionMetaCache") - } - pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} - pcm.proxyClient.Insert(TestProxyID, p1) - err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) - assert.Error(t, err) - }) - - t.Run("mock error code", func(t *testing.T) { - ctx := context.Background() - p1 := newMockProxy() - mockErr := errors.New("mock error") - p1.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - return merr.Status(mockErr), nil - } - pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} - pcm.proxyClient.Insert(TestProxyID, p1) - err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) - assert.Error(t, err) - }) - - t.Run("mock proxy service down", func(t *testing.T) { - ctx := context.Background() - p1 := newMockProxy() - p1.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - return nil, merr.ErrNodeNotFound - } - pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} - pcm.proxyClient.Insert(TestProxyID, p1) - - err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) - assert.NoError(t, err) - }) - - t.Run("normal case", func(t *testing.T) { - ctx := context.Background() - p1 := newMockProxy() - p1.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - return merr.Success(), nil - } - pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} - pcm.proxyClient.Insert(TestProxyID, p1) - err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) - assert.NoError(t, err) - }) -} - -func TestProxyClientManager_InvalidateCredentialCache(t *testing.T) { - t.Run("empty proxy list", func(t *testing.T) { - ctx := context.Background() - pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} - err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{}) - assert.NoError(t, err) - }) - - t.Run("mock rpc error", func(t *testing.T) { - ctx := context.Background() - p1 := newMockProxy() - p1.InvalidateCredentialCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) { - return merr.Success(), errors.New("error mock InvalidateCredentialCache") - } - pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} - pcm.proxyClient.Insert(TestProxyID, p1) - err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{}) - assert.Error(t, err) - }) - - t.Run("mock error code", func(t *testing.T) { - ctx := context.Background() - p1 := newMockProxy() - mockErr := errors.New("mock error") - p1.InvalidateCredentialCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) { - return merr.Status(mockErr), nil - } - pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} - pcm.proxyClient.Insert(TestProxyID, p1) - err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{}) - assert.Error(t, err) - }) - - t.Run("normal case", func(t *testing.T) { - ctx := context.Background() - p1 := newMockProxy() - p1.InvalidateCredentialCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) { - return merr.Success(), nil - } - pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} - pcm.proxyClient.Insert(TestProxyID, p1) - err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{}) - assert.NoError(t, err) - }) -} - -func TestProxyClientManager_RefreshPolicyInfoCache(t *testing.T) { - t.Run("empty proxy list", func(t *testing.T) { - ctx := context.Background() - pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} - err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{}) - assert.NoError(t, err) - }) - - t.Run("mock rpc error", func(t *testing.T) { - ctx := context.Background() - p1 := newMockProxy() - p1.RefreshPolicyInfoCacheFunc = func(ctx context.Context, request *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) { - return merr.Success(), errors.New("error mock RefreshPolicyInfoCache") - } - pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} - pcm.proxyClient.Insert(TestProxyID, p1) - err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{}) - assert.Error(t, err) - }) - - t.Run("mock error code", func(t *testing.T) { - ctx := context.Background() - p1 := newMockProxy() - mockErr := errors.New("mock error") - p1.RefreshPolicyInfoCacheFunc = func(ctx context.Context, request *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) { - return merr.Status(mockErr), nil - } - pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} - pcm.proxyClient.Insert(TestProxyID, p1) - err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{}) - assert.Error(t, err) - }) - - t.Run("normal case", func(t *testing.T) { - ctx := context.Background() - p1 := newMockProxy() - p1.RefreshPolicyInfoCacheFunc = func(ctx context.Context, request *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) { - return merr.Success(), nil - } - pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} - pcm.proxyClient.Insert(TestProxyID, p1) - err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{}) - assert.NoError(t, err) - }) -} diff --git a/internal/rootcoord/quota_center.go b/internal/rootcoord/quota_center.go index 2a611f2b4b..a44ba6b628 100644 --- a/internal/rootcoord/quota_center.go +++ b/internal/rootcoord/quota_center.go @@ -34,6 +34,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/tso" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" @@ -86,7 +87,7 @@ type collectionStates = map[milvuspb.QuotaState]commonpb.ErrorCode // If necessary, user can also manually force to deny RW requests. type QuotaCenter struct { // clients - proxies *proxyClientManager + proxies proxyutil.ProxyClientManagerInterface queryCoord types.QueryCoordClient dataCoord types.DataCoordClient meta IMetaTable @@ -113,7 +114,7 @@ type QuotaCenter struct { } // NewQuotaCenter returns a new QuotaCenter. -func NewQuotaCenter(proxies *proxyClientManager, queryCoord types.QueryCoordClient, dataCoord types.DataCoordClient, tsoAllocator tso.Allocator, meta IMetaTable) *QuotaCenter { +func NewQuotaCenter(proxies proxyutil.ProxyClientManagerInterface, queryCoord types.QueryCoordClient, dataCoord types.DataCoordClient, tsoAllocator tso.Allocator, meta IMetaTable) *QuotaCenter { return &QuotaCenter{ proxies: proxies, queryCoord: queryCoord, diff --git a/internal/rootcoord/quota_center_test.go b/internal/rootcoord/quota_center_test.go index 670ec78212..fb5d073465 100644 --- a/internal/rootcoord/quota_center_test.go +++ b/internal/rootcoord/quota_center_test.go @@ -33,7 +33,7 @@ import ( "github.com/milvus-io/milvus/internal/mocks" "github.com/milvus-io/milvus/internal/proto/internalpb" mockrootcoord "github.com/milvus-io/milvus/internal/rootcoord/mocks" - "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" @@ -51,7 +51,8 @@ func TestQuotaCenter(t *testing.T) { assert.NoError(t, err) core.tsoAllocator = newMockTsoAllocator() - pcm := newProxyClientManager(core.proxyCreator) + pcm := proxyutil.NewMockProxyClientManager(t) + pcm.EXPECT().GetProxyMetrics(mock.Anything).Return(nil, nil).Maybe() dc := mocks.NewMockDataCoordClient(t) dc.EXPECT().GetMetrics(mock.Anything, mock.Anything).Return(nil, nil).Maybe() @@ -531,10 +532,8 @@ func TestQuotaCenter(t *testing.T) { t.Run("test setRates", func(t *testing.T) { qc := mocks.NewMockQueryCoordClient(t) - p1 := mocks.NewMockProxyClient(t) - p1.EXPECT().SetRates(mock.Anything, mock.Anything).Return(nil, nil) - pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()} - pcm.proxyClient.Insert(TestProxyID, p1) + pcm.EXPECT().GetProxyCount().Return(1) + pcm.EXPECT().SetRates(mock.Anything, mock.Anything).Return(nil) meta := mockrootcoord.NewIMetaTable(t) meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe() quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta) diff --git a/internal/rootcoord/root_coord.go b/internal/rootcoord/root_coord.go index 11a1de267f..4f0e7fe1df 100644 --- a/internal/rootcoord/root_coord.go +++ b/internal/rootcoord/root_coord.go @@ -51,6 +51,7 @@ import ( "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/importutil" + "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/internal/util/sessionutil" tsoutil2 "github.com/milvus-io/milvus/internal/util/tsoutil" "github.com/milvus-io/milvus/pkg/common" @@ -101,9 +102,9 @@ type Core struct { metaKVCreator metaKVCreator - proxyCreator proxyCreator - proxyManager *proxyManager - proxyClientManager *proxyClientManager + proxyCreator proxyutil.ProxyCreator + proxyWatcher *proxyutil.ProxyWatcher + proxyClientManager proxyutil.ProxyClientManagerInterface metricsCacheManager *metricsinfo.MetricsCacheManager @@ -144,7 +145,7 @@ func NewCore(c context.Context, factory dependency.Factory) (*Core, error) { } core.UpdateStateCode(commonpb.StateCode_Abnormal) - core.SetProxyCreator(DefaultProxyCreator) + core.SetProxyCreator(proxyutil.DefaultProxyCreator) return core, nil } @@ -473,21 +474,20 @@ func (c *Core) initInternal() error { c.chanTimeTick = newTimeTickSync(c.ctx, c.session.ServerID, c.factory, chanMap) log.Info("create TimeTick sync done") - c.proxyClientManager = newProxyClientManager(c.proxyCreator) + c.proxyClientManager = proxyutil.NewProxyClientManager(c.proxyCreator) c.broker = newServerBroker(c) c.ddlTsLockManager = newDdlTsLockManager(c.tsoAllocator) c.garbageCollector = newBgGarbageCollector(c) c.stepExecutor = newBgStepExecutor(c.ctx) - c.proxyManager = newProxyManager( - c.ctx, + c.proxyWatcher = proxyutil.NewProxyWatcher( c.etcdCli, c.chanTimeTick.initSessions, c.proxyClientManager.AddProxyClients, ) - c.proxyManager.AddSessionFunc(c.chanTimeTick.addSession, c.proxyClientManager.AddProxyClient) - c.proxyManager.DelSessionFunc(c.chanTimeTick.delSession, c.proxyClientManager.DelProxyClient) + c.proxyWatcher.AddSessionFunc(c.chanTimeTick.addSession, c.proxyClientManager.AddProxyClient) + c.proxyWatcher.DelSessionFunc(c.chanTimeTick.delSession, c.proxyClientManager.DelProxyClient) log.Info("init proxy manager done") c.metricsCacheManager = metricsinfo.NewMetricsCacheManager() @@ -694,7 +694,7 @@ func (c *Core) restore(ctx context.Context) error { } func (c *Core) startInternal() error { - if err := c.proxyManager.WatchProxy(); err != nil { + if err := c.proxyWatcher.WatchProxy(c.ctx); err != nil { log.Fatal("rootcoord failed to watch proxy", zap.Error(err)) // you can not just stuck here, panic(err) @@ -789,8 +789,8 @@ func (c *Core) Stop() error { c.UpdateStateCode(commonpb.StateCode_Abnormal) c.stopExecutor() c.stopScheduler() - if c.proxyManager != nil { - c.proxyManager.Stop() + if c.proxyWatcher != nil { + c.proxyWatcher.Stop() } c.cancelIfNotNil() if c.quotaCenter != nil { diff --git a/internal/rootcoord/step.go b/internal/rootcoord/step.go index 5a51996a6c..dd194a0d73 100644 --- a/internal/rootcoord/step.go +++ b/internal/rootcoord/step.go @@ -24,6 +24,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/internal/metastore/model" pb "github.com/milvus-io/milvus/internal/proto/etcdpb" + "github.com/milvus-io/milvus/internal/util/proxyutil" ) type stepPriority int @@ -170,7 +171,7 @@ type expireCacheStep struct { collectionNames []string collectionID UniqueID ts Timestamp - opts []expireCacheOpt + opts []proxyutil.ExpireCacheOpt } func (s *expireCacheStep) Execute(ctx context.Context) ([]nestedStep, error) { diff --git a/internal/util/proxyutil/mock_proxy_client_manager.go b/internal/util/proxyutil/mock_proxy_client_manager.go new file mode 100644 index 0000000000..ea53e94d80 --- /dev/null +++ b/internal/util/proxyutil/mock_proxy_client_manager.go @@ -0,0 +1,566 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package proxyutil + +import ( + context "context" + + milvuspb "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + mock "github.com/stretchr/testify/mock" + + proxypb "github.com/milvus-io/milvus/internal/proto/proxypb" + + sessionutil "github.com/milvus-io/milvus/internal/util/sessionutil" + + types "github.com/milvus-io/milvus/internal/types" + + typeutil "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +// MockProxyClientManager is an autogenerated mock type for the ProxyClientManagerInterface type +type MockProxyClientManager struct { + mock.Mock +} + +type MockProxyClientManager_Expecter struct { + mock *mock.Mock +} + +func (_m *MockProxyClientManager) EXPECT() *MockProxyClientManager_Expecter { + return &MockProxyClientManager_Expecter{mock: &_m.Mock} +} + +// AddProxyClient provides a mock function with given fields: session +func (_m *MockProxyClientManager) AddProxyClient(session *sessionutil.Session) { + _m.Called(session) +} + +// MockProxyClientManager_AddProxyClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddProxyClient' +type MockProxyClientManager_AddProxyClient_Call struct { + *mock.Call +} + +// AddProxyClient is a helper method to define mock.On call +// - session *sessionutil.Session +func (_e *MockProxyClientManager_Expecter) AddProxyClient(session interface{}) *MockProxyClientManager_AddProxyClient_Call { + return &MockProxyClientManager_AddProxyClient_Call{Call: _e.mock.On("AddProxyClient", session)} +} + +func (_c *MockProxyClientManager_AddProxyClient_Call) Run(run func(session *sessionutil.Session)) *MockProxyClientManager_AddProxyClient_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*sessionutil.Session)) + }) + return _c +} + +func (_c *MockProxyClientManager_AddProxyClient_Call) Return() *MockProxyClientManager_AddProxyClient_Call { + _c.Call.Return() + return _c +} + +func (_c *MockProxyClientManager_AddProxyClient_Call) RunAndReturn(run func(*sessionutil.Session)) *MockProxyClientManager_AddProxyClient_Call { + _c.Call.Return(run) + return _c +} + +// AddProxyClients provides a mock function with given fields: session +func (_m *MockProxyClientManager) AddProxyClients(session []*sessionutil.Session) { + _m.Called(session) +} + +// MockProxyClientManager_AddProxyClients_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddProxyClients' +type MockProxyClientManager_AddProxyClients_Call struct { + *mock.Call +} + +// AddProxyClients is a helper method to define mock.On call +// - session []*sessionutil.Session +func (_e *MockProxyClientManager_Expecter) AddProxyClients(session interface{}) *MockProxyClientManager_AddProxyClients_Call { + return &MockProxyClientManager_AddProxyClients_Call{Call: _e.mock.On("AddProxyClients", session)} +} + +func (_c *MockProxyClientManager_AddProxyClients_Call) Run(run func(session []*sessionutil.Session)) *MockProxyClientManager_AddProxyClients_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].([]*sessionutil.Session)) + }) + return _c +} + +func (_c *MockProxyClientManager_AddProxyClients_Call) Return() *MockProxyClientManager_AddProxyClients_Call { + _c.Call.Return() + return _c +} + +func (_c *MockProxyClientManager_AddProxyClients_Call) RunAndReturn(run func([]*sessionutil.Session)) *MockProxyClientManager_AddProxyClients_Call { + _c.Call.Return(run) + return _c +} + +// DelProxyClient provides a mock function with given fields: s +func (_m *MockProxyClientManager) DelProxyClient(s *sessionutil.Session) { + _m.Called(s) +} + +// MockProxyClientManager_DelProxyClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DelProxyClient' +type MockProxyClientManager_DelProxyClient_Call struct { + *mock.Call +} + +// DelProxyClient is a helper method to define mock.On call +// - s *sessionutil.Session +func (_e *MockProxyClientManager_Expecter) DelProxyClient(s interface{}) *MockProxyClientManager_DelProxyClient_Call { + return &MockProxyClientManager_DelProxyClient_Call{Call: _e.mock.On("DelProxyClient", s)} +} + +func (_c *MockProxyClientManager_DelProxyClient_Call) Run(run func(s *sessionutil.Session)) *MockProxyClientManager_DelProxyClient_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(*sessionutil.Session)) + }) + return _c +} + +func (_c *MockProxyClientManager_DelProxyClient_Call) Return() *MockProxyClientManager_DelProxyClient_Call { + _c.Call.Return() + return _c +} + +func (_c *MockProxyClientManager_DelProxyClient_Call) RunAndReturn(run func(*sessionutil.Session)) *MockProxyClientManager_DelProxyClient_Call { + _c.Call.Return(run) + return _c +} + +// GetComponentStates provides a mock function with given fields: ctx +func (_m *MockProxyClientManager) GetComponentStates(ctx context.Context) (map[int64]*milvuspb.ComponentStates, error) { + ret := _m.Called(ctx) + + var r0 map[int64]*milvuspb.ComponentStates + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (map[int64]*milvuspb.ComponentStates, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) map[int64]*milvuspb.ComponentStates); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[int64]*milvuspb.ComponentStates) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProxyClientManager_GetComponentStates_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetComponentStates' +type MockProxyClientManager_GetComponentStates_Call struct { + *mock.Call +} + +// GetComponentStates is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockProxyClientManager_Expecter) GetComponentStates(ctx interface{}) *MockProxyClientManager_GetComponentStates_Call { + return &MockProxyClientManager_GetComponentStates_Call{Call: _e.mock.On("GetComponentStates", ctx)} +} + +func (_c *MockProxyClientManager_GetComponentStates_Call) Run(run func(ctx context.Context)) *MockProxyClientManager_GetComponentStates_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockProxyClientManager_GetComponentStates_Call) Return(_a0 map[int64]*milvuspb.ComponentStates, _a1 error) *MockProxyClientManager_GetComponentStates_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProxyClientManager_GetComponentStates_Call) RunAndReturn(run func(context.Context) (map[int64]*milvuspb.ComponentStates, error)) *MockProxyClientManager_GetComponentStates_Call { + _c.Call.Return(run) + return _c +} + +// GetProxyClients provides a mock function with given fields: +func (_m *MockProxyClientManager) GetProxyClients() *typeutil.ConcurrentMap[int64, types.ProxyClient] { + ret := _m.Called() + + var r0 *typeutil.ConcurrentMap[int64, types.ProxyClient] + if rf, ok := ret.Get(0).(func() *typeutil.ConcurrentMap[int64, types.ProxyClient]); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*typeutil.ConcurrentMap[int64, types.ProxyClient]) + } + } + + return r0 +} + +// MockProxyClientManager_GetProxyClients_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetProxyClients' +type MockProxyClientManager_GetProxyClients_Call struct { + *mock.Call +} + +// GetProxyClients is a helper method to define mock.On call +func (_e *MockProxyClientManager_Expecter) GetProxyClients() *MockProxyClientManager_GetProxyClients_Call { + return &MockProxyClientManager_GetProxyClients_Call{Call: _e.mock.On("GetProxyClients")} +} + +func (_c *MockProxyClientManager_GetProxyClients_Call) Run(run func()) *MockProxyClientManager_GetProxyClients_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockProxyClientManager_GetProxyClients_Call) Return(_a0 *typeutil.ConcurrentMap[int64, types.ProxyClient]) *MockProxyClientManager_GetProxyClients_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockProxyClientManager_GetProxyClients_Call) RunAndReturn(run func() *typeutil.ConcurrentMap[int64, types.ProxyClient]) *MockProxyClientManager_GetProxyClients_Call { + _c.Call.Return(run) + return _c +} + +// GetProxyCount provides a mock function with given fields: +func (_m *MockProxyClientManager) GetProxyCount() int { + ret := _m.Called() + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + return r0 +} + +// MockProxyClientManager_GetProxyCount_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetProxyCount' +type MockProxyClientManager_GetProxyCount_Call struct { + *mock.Call +} + +// GetProxyCount is a helper method to define mock.On call +func (_e *MockProxyClientManager_Expecter) GetProxyCount() *MockProxyClientManager_GetProxyCount_Call { + return &MockProxyClientManager_GetProxyCount_Call{Call: _e.mock.On("GetProxyCount")} +} + +func (_c *MockProxyClientManager_GetProxyCount_Call) Run(run func()) *MockProxyClientManager_GetProxyCount_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockProxyClientManager_GetProxyCount_Call) Return(_a0 int) *MockProxyClientManager_GetProxyCount_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockProxyClientManager_GetProxyCount_Call) RunAndReturn(run func() int) *MockProxyClientManager_GetProxyCount_Call { + _c.Call.Return(run) + return _c +} + +// GetProxyMetrics provides a mock function with given fields: ctx +func (_m *MockProxyClientManager) GetProxyMetrics(ctx context.Context) ([]*milvuspb.GetMetricsResponse, error) { + ret := _m.Called(ctx) + + var r0 []*milvuspb.GetMetricsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]*milvuspb.GetMetricsResponse, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []*milvuspb.GetMetricsResponse); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*milvuspb.GetMetricsResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockProxyClientManager_GetProxyMetrics_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetProxyMetrics' +type MockProxyClientManager_GetProxyMetrics_Call struct { + *mock.Call +} + +// GetProxyMetrics is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockProxyClientManager_Expecter) GetProxyMetrics(ctx interface{}) *MockProxyClientManager_GetProxyMetrics_Call { + return &MockProxyClientManager_GetProxyMetrics_Call{Call: _e.mock.On("GetProxyMetrics", ctx)} +} + +func (_c *MockProxyClientManager_GetProxyMetrics_Call) Run(run func(ctx context.Context)) *MockProxyClientManager_GetProxyMetrics_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockProxyClientManager_GetProxyMetrics_Call) Return(_a0 []*milvuspb.GetMetricsResponse, _a1 error) *MockProxyClientManager_GetProxyMetrics_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockProxyClientManager_GetProxyMetrics_Call) RunAndReturn(run func(context.Context) ([]*milvuspb.GetMetricsResponse, error)) *MockProxyClientManager_GetProxyMetrics_Call { + _c.Call.Return(run) + return _c +} + +// InvalidateCollectionMetaCache provides a mock function with given fields: ctx, request, opts +func (_m *MockProxyClientManager) InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest, opts ...ExpireCacheOpt) error { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, request) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCollMetaCacheRequest, ...ExpireCacheOpt) error); ok { + r0 = rf(ctx, request, opts...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockProxyClientManager_InvalidateCollectionMetaCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateCollectionMetaCache' +type MockProxyClientManager_InvalidateCollectionMetaCache_Call struct { + *mock.Call +} + +// InvalidateCollectionMetaCache is a helper method to define mock.On call +// - ctx context.Context +// - request *proxypb.InvalidateCollMetaCacheRequest +// - opts ...ExpireCacheOpt +func (_e *MockProxyClientManager_Expecter) InvalidateCollectionMetaCache(ctx interface{}, request interface{}, opts ...interface{}) *MockProxyClientManager_InvalidateCollectionMetaCache_Call { + return &MockProxyClientManager_InvalidateCollectionMetaCache_Call{Call: _e.mock.On("InvalidateCollectionMetaCache", + append([]interface{}{ctx, request}, opts...)...)} +} + +func (_c *MockProxyClientManager_InvalidateCollectionMetaCache_Call) Run(run func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest, opts ...ExpireCacheOpt)) *MockProxyClientManager_InvalidateCollectionMetaCache_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]ExpireCacheOpt, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(ExpireCacheOpt) + } + } + run(args[0].(context.Context), args[1].(*proxypb.InvalidateCollMetaCacheRequest), variadicArgs...) + }) + return _c +} + +func (_c *MockProxyClientManager_InvalidateCollectionMetaCache_Call) Return(_a0 error) *MockProxyClientManager_InvalidateCollectionMetaCache_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockProxyClientManager_InvalidateCollectionMetaCache_Call) RunAndReturn(run func(context.Context, *proxypb.InvalidateCollMetaCacheRequest, ...ExpireCacheOpt) error) *MockProxyClientManager_InvalidateCollectionMetaCache_Call { + _c.Call.Return(run) + return _c +} + +// InvalidateCredentialCache provides a mock function with given fields: ctx, request +func (_m *MockProxyClientManager) InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) error { + ret := _m.Called(ctx, request) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.InvalidateCredCacheRequest) error); ok { + r0 = rf(ctx, request) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockProxyClientManager_InvalidateCredentialCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InvalidateCredentialCache' +type MockProxyClientManager_InvalidateCredentialCache_Call struct { + *mock.Call +} + +// InvalidateCredentialCache is a helper method to define mock.On call +// - ctx context.Context +// - request *proxypb.InvalidateCredCacheRequest +func (_e *MockProxyClientManager_Expecter) InvalidateCredentialCache(ctx interface{}, request interface{}) *MockProxyClientManager_InvalidateCredentialCache_Call { + return &MockProxyClientManager_InvalidateCredentialCache_Call{Call: _e.mock.On("InvalidateCredentialCache", ctx, request)} +} + +func (_c *MockProxyClientManager_InvalidateCredentialCache_Call) Run(run func(ctx context.Context, request *proxypb.InvalidateCredCacheRequest)) *MockProxyClientManager_InvalidateCredentialCache_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*proxypb.InvalidateCredCacheRequest)) + }) + return _c +} + +func (_c *MockProxyClientManager_InvalidateCredentialCache_Call) Return(_a0 error) *MockProxyClientManager_InvalidateCredentialCache_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockProxyClientManager_InvalidateCredentialCache_Call) RunAndReturn(run func(context.Context, *proxypb.InvalidateCredCacheRequest) error) *MockProxyClientManager_InvalidateCredentialCache_Call { + _c.Call.Return(run) + return _c +} + +// RefreshPolicyInfoCache provides a mock function with given fields: ctx, req +func (_m *MockProxyClientManager) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) error { + ret := _m.Called(ctx, req) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.RefreshPolicyInfoCacheRequest) error); ok { + r0 = rf(ctx, req) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockProxyClientManager_RefreshPolicyInfoCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RefreshPolicyInfoCache' +type MockProxyClientManager_RefreshPolicyInfoCache_Call struct { + *mock.Call +} + +// RefreshPolicyInfoCache is a helper method to define mock.On call +// - ctx context.Context +// - req *proxypb.RefreshPolicyInfoCacheRequest +func (_e *MockProxyClientManager_Expecter) RefreshPolicyInfoCache(ctx interface{}, req interface{}) *MockProxyClientManager_RefreshPolicyInfoCache_Call { + return &MockProxyClientManager_RefreshPolicyInfoCache_Call{Call: _e.mock.On("RefreshPolicyInfoCache", ctx, req)} +} + +func (_c *MockProxyClientManager_RefreshPolicyInfoCache_Call) Run(run func(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest)) *MockProxyClientManager_RefreshPolicyInfoCache_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*proxypb.RefreshPolicyInfoCacheRequest)) + }) + return _c +} + +func (_c *MockProxyClientManager_RefreshPolicyInfoCache_Call) Return(_a0 error) *MockProxyClientManager_RefreshPolicyInfoCache_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockProxyClientManager_RefreshPolicyInfoCache_Call) RunAndReturn(run func(context.Context, *proxypb.RefreshPolicyInfoCacheRequest) error) *MockProxyClientManager_RefreshPolicyInfoCache_Call { + _c.Call.Return(run) + return _c +} + +// SetRates provides a mock function with given fields: ctx, request +func (_m *MockProxyClientManager) SetRates(ctx context.Context, request *proxypb.SetRatesRequest) error { + ret := _m.Called(ctx, request) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.SetRatesRequest) error); ok { + r0 = rf(ctx, request) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockProxyClientManager_SetRates_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetRates' +type MockProxyClientManager_SetRates_Call struct { + *mock.Call +} + +// SetRates is a helper method to define mock.On call +// - ctx context.Context +// - request *proxypb.SetRatesRequest +func (_e *MockProxyClientManager_Expecter) SetRates(ctx interface{}, request interface{}) *MockProxyClientManager_SetRates_Call { + return &MockProxyClientManager_SetRates_Call{Call: _e.mock.On("SetRates", ctx, request)} +} + +func (_c *MockProxyClientManager_SetRates_Call) Run(run func(ctx context.Context, request *proxypb.SetRatesRequest)) *MockProxyClientManager_SetRates_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*proxypb.SetRatesRequest)) + }) + return _c +} + +func (_c *MockProxyClientManager_SetRates_Call) Return(_a0 error) *MockProxyClientManager_SetRates_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockProxyClientManager_SetRates_Call) RunAndReturn(run func(context.Context, *proxypb.SetRatesRequest) error) *MockProxyClientManager_SetRates_Call { + _c.Call.Return(run) + return _c +} + +// UpdateCredentialCache provides a mock function with given fields: ctx, request +func (_m *MockProxyClientManager) UpdateCredentialCache(ctx context.Context, request *proxypb.UpdateCredCacheRequest) error { + ret := _m.Called(ctx, request) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *proxypb.UpdateCredCacheRequest) error); ok { + r0 = rf(ctx, request) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockProxyClientManager_UpdateCredentialCache_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCredentialCache' +type MockProxyClientManager_UpdateCredentialCache_Call struct { + *mock.Call +} + +// UpdateCredentialCache is a helper method to define mock.On call +// - ctx context.Context +// - request *proxypb.UpdateCredCacheRequest +func (_e *MockProxyClientManager_Expecter) UpdateCredentialCache(ctx interface{}, request interface{}) *MockProxyClientManager_UpdateCredentialCache_Call { + return &MockProxyClientManager_UpdateCredentialCache_Call{Call: _e.mock.On("UpdateCredentialCache", ctx, request)} +} + +func (_c *MockProxyClientManager_UpdateCredentialCache_Call) Run(run func(ctx context.Context, request *proxypb.UpdateCredCacheRequest)) *MockProxyClientManager_UpdateCredentialCache_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(*proxypb.UpdateCredCacheRequest)) + }) + return _c +} + +func (_c *MockProxyClientManager_UpdateCredentialCache_Call) Return(_a0 error) *MockProxyClientManager_UpdateCredentialCache_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockProxyClientManager_UpdateCredentialCache_Call) RunAndReturn(run func(context.Context, *proxypb.UpdateCredCacheRequest) error) *MockProxyClientManager_UpdateCredentialCache_Call { + _c.Call.Return(run) + return _c +} + +// NewMockProxyClientManager creates a new instance of MockProxyClientManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockProxyClientManager(t interface { + mock.TestingT + Cleanup(func()) +}) *MockProxyClientManager { + mock := &MockProxyClientManager{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/util/proxyutil/mock_proxy_watcher.go b/internal/util/proxyutil/mock_proxy_watcher.go new file mode 100644 index 0000000000..0aed0cb5b6 --- /dev/null +++ b/internal/util/proxyutil/mock_proxy_watcher.go @@ -0,0 +1,203 @@ +// Code generated by mockery v2.32.4. DO NOT EDIT. + +package proxyutil + +import ( + context "context" + + sessionutil "github.com/milvus-io/milvus/internal/util/sessionutil" + mock "github.com/stretchr/testify/mock" +) + +// MockProxyWatcher is an autogenerated mock type for the ProxyWatcherInterface type +type MockProxyWatcher struct { + mock.Mock +} + +type MockProxyWatcher_Expecter struct { + mock *mock.Mock +} + +func (_m *MockProxyWatcher) EXPECT() *MockProxyWatcher_Expecter { + return &MockProxyWatcher_Expecter{mock: &_m.Mock} +} + +// AddSessionFunc provides a mock function with given fields: fns +func (_m *MockProxyWatcher) AddSessionFunc(fns ...func(*sessionutil.Session)) { + _va := make([]interface{}, len(fns)) + for _i := range fns { + _va[_i] = fns[_i] + } + var _ca []interface{} + _ca = append(_ca, _va...) + _m.Called(_ca...) +} + +// MockProxyWatcher_AddSessionFunc_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddSessionFunc' +type MockProxyWatcher_AddSessionFunc_Call struct { + *mock.Call +} + +// AddSessionFunc is a helper method to define mock.On call +// - fns ...func(*sessionutil.Session) +func (_e *MockProxyWatcher_Expecter) AddSessionFunc(fns ...interface{}) *MockProxyWatcher_AddSessionFunc_Call { + return &MockProxyWatcher_AddSessionFunc_Call{Call: _e.mock.On("AddSessionFunc", + append([]interface{}{}, fns...)...)} +} + +func (_c *MockProxyWatcher_AddSessionFunc_Call) Run(run func(fns ...func(*sessionutil.Session))) *MockProxyWatcher_AddSessionFunc_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]func(*sessionutil.Session), len(args)-0) + for i, a := range args[0:] { + if a != nil { + variadicArgs[i] = a.(func(*sessionutil.Session)) + } + } + run(variadicArgs...) + }) + return _c +} + +func (_c *MockProxyWatcher_AddSessionFunc_Call) Return() *MockProxyWatcher_AddSessionFunc_Call { + _c.Call.Return() + return _c +} + +func (_c *MockProxyWatcher_AddSessionFunc_Call) RunAndReturn(run func(...func(*sessionutil.Session))) *MockProxyWatcher_AddSessionFunc_Call { + _c.Call.Return(run) + return _c +} + +// DelSessionFunc provides a mock function with given fields: fns +func (_m *MockProxyWatcher) DelSessionFunc(fns ...func(*sessionutil.Session)) { + _va := make([]interface{}, len(fns)) + for _i := range fns { + _va[_i] = fns[_i] + } + var _ca []interface{} + _ca = append(_ca, _va...) + _m.Called(_ca...) +} + +// MockProxyWatcher_DelSessionFunc_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DelSessionFunc' +type MockProxyWatcher_DelSessionFunc_Call struct { + *mock.Call +} + +// DelSessionFunc is a helper method to define mock.On call +// - fns ...func(*sessionutil.Session) +func (_e *MockProxyWatcher_Expecter) DelSessionFunc(fns ...interface{}) *MockProxyWatcher_DelSessionFunc_Call { + return &MockProxyWatcher_DelSessionFunc_Call{Call: _e.mock.On("DelSessionFunc", + append([]interface{}{}, fns...)...)} +} + +func (_c *MockProxyWatcher_DelSessionFunc_Call) Run(run func(fns ...func(*sessionutil.Session))) *MockProxyWatcher_DelSessionFunc_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]func(*sessionutil.Session), len(args)-0) + for i, a := range args[0:] { + if a != nil { + variadicArgs[i] = a.(func(*sessionutil.Session)) + } + } + run(variadicArgs...) + }) + return _c +} + +func (_c *MockProxyWatcher_DelSessionFunc_Call) Return() *MockProxyWatcher_DelSessionFunc_Call { + _c.Call.Return() + return _c +} + +func (_c *MockProxyWatcher_DelSessionFunc_Call) RunAndReturn(run func(...func(*sessionutil.Session))) *MockProxyWatcher_DelSessionFunc_Call { + _c.Call.Return(run) + return _c +} + +// Stop provides a mock function with given fields: +func (_m *MockProxyWatcher) Stop() { + _m.Called() +} + +// MockProxyWatcher_Stop_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Stop' +type MockProxyWatcher_Stop_Call struct { + *mock.Call +} + +// Stop is a helper method to define mock.On call +func (_e *MockProxyWatcher_Expecter) Stop() *MockProxyWatcher_Stop_Call { + return &MockProxyWatcher_Stop_Call{Call: _e.mock.On("Stop")} +} + +func (_c *MockProxyWatcher_Stop_Call) Run(run func()) *MockProxyWatcher_Stop_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockProxyWatcher_Stop_Call) Return() *MockProxyWatcher_Stop_Call { + _c.Call.Return() + return _c +} + +func (_c *MockProxyWatcher_Stop_Call) RunAndReturn(run func()) *MockProxyWatcher_Stop_Call { + _c.Call.Return(run) + return _c +} + +// WatchProxy provides a mock function with given fields: ctx +func (_m *MockProxyWatcher) WatchProxy(ctx context.Context) error { + ret := _m.Called(ctx) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockProxyWatcher_WatchProxy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WatchProxy' +type MockProxyWatcher_WatchProxy_Call struct { + *mock.Call +} + +// WatchProxy is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockProxyWatcher_Expecter) WatchProxy(ctx interface{}) *MockProxyWatcher_WatchProxy_Call { + return &MockProxyWatcher_WatchProxy_Call{Call: _e.mock.On("WatchProxy", ctx)} +} + +func (_c *MockProxyWatcher_WatchProxy_Call) Run(run func(ctx context.Context)) *MockProxyWatcher_WatchProxy_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockProxyWatcher_WatchProxy_Call) Return(_a0 error) *MockProxyWatcher_WatchProxy_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockProxyWatcher_WatchProxy_Call) RunAndReturn(run func(context.Context) error) *MockProxyWatcher_WatchProxy_Call { + _c.Call.Return(run) + return _c +} + +// NewMockProxyWatcher creates a new instance of MockProxyWatcher. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockProxyWatcher(t interface { + mock.TestingT + Cleanup(func()) +}) *MockProxyWatcher { + mock := &MockProxyWatcher{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/rootcoord/proxy_client_manager.go b/internal/util/proxyutil/proxy_client_manager.go similarity index 71% rename from internal/rootcoord/proxy_client_manager.go rename to internal/util/proxyutil/proxy_client_manager.go index 0141ef4a55..bcdce057c2 100644 --- a/internal/rootcoord/proxy_client_manager.go +++ b/internal/util/proxyutil/proxy_client_manager.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package rootcoord +package proxyutil import ( "context" @@ -33,12 +33,39 @@ import ( "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/merr" "github.com/milvus-io/milvus/pkg/util/metricsinfo" "github.com/milvus-io/milvus/pkg/util/typeutil" ) -type proxyCreator func(ctx context.Context, addr string, nodeID int64) (types.ProxyClient, error) +type ExpireCacheConfig struct { + withDropFlag bool +} + +func (c ExpireCacheConfig) Apply(req *proxypb.InvalidateCollMetaCacheRequest) { + if !c.withDropFlag { + return + } + if req.GetBase() == nil { + req.Base = commonpbutil.NewMsgBase() + } + req.Base.MsgType = commonpb.MsgType_DropCollection +} + +func DefaultExpireCacheConfig() ExpireCacheConfig { + return ExpireCacheConfig{withDropFlag: false} +} + +type ExpireCacheOpt func(c *ExpireCacheConfig) + +func ExpireCacheWithDropFlag() ExpireCacheOpt { + return func(c *ExpireCacheConfig) { + c.withDropFlag = true + } +} + +type ProxyCreator func(ctx context.Context, addr string, nodeID int64) (types.ProxyClient, error) func DefaultProxyCreator(ctx context.Context, addr string, nodeID int64) (types.ProxyClient, error) { cli, err := grpcproxyclient.NewClient(ctx, addr, nodeID) @@ -48,39 +75,55 @@ func DefaultProxyCreator(ctx context.Context, addr string, nodeID int64) (types. return cli, nil } -type proxyClientManager struct { - creator proxyCreator - proxyClient *typeutil.ConcurrentMap[int64, types.ProxyClient] - helper proxyClientManagerHelper -} - -type proxyClientManagerHelper struct { +type ProxyClientManagerHelper struct { afterConnect func() } -var defaultClientManagerHelper = proxyClientManagerHelper{ +var defaultClientManagerHelper = ProxyClientManagerHelper{ afterConnect: func() {}, } -func newProxyClientManager(creator proxyCreator) *proxyClientManager { - return &proxyClientManager{ +type ProxyClientManagerInterface interface { + AddProxyClient(session *sessionutil.Session) + AddProxyClients(session []*sessionutil.Session) + GetProxyClients() *typeutil.ConcurrentMap[int64, types.ProxyClient] + DelProxyClient(s *sessionutil.Session) + GetProxyCount() int + + InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest, opts ...ExpireCacheOpt) error + InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) error + UpdateCredentialCache(ctx context.Context, request *proxypb.UpdateCredCacheRequest) error + RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) error + GetProxyMetrics(ctx context.Context) ([]*milvuspb.GetMetricsResponse, error) + SetRates(ctx context.Context, request *proxypb.SetRatesRequest) error + GetComponentStates(ctx context.Context) (map[int64]*milvuspb.ComponentStates, error) +} + +type ProxyClientManager struct { + creator ProxyCreator + proxyClient *typeutil.ConcurrentMap[int64, types.ProxyClient] + helper ProxyClientManagerHelper +} + +func NewProxyClientManager(creator ProxyCreator) *ProxyClientManager { + return &ProxyClientManager{ creator: creator, proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient](), helper: defaultClientManagerHelper, } } -func (p *proxyClientManager) AddProxyClients(sessions []*sessionutil.Session) { +func (p *ProxyClientManager) AddProxyClients(sessions []*sessionutil.Session) { for _, session := range sessions { p.AddProxyClient(session) } } -func (p *proxyClientManager) GetProxyClients() *typeutil.ConcurrentMap[int64, types.ProxyClient] { +func (p *ProxyClientManager) GetProxyClients() *typeutil.ConcurrentMap[int64, types.ProxyClient] { return p.proxyClient } -func (p *proxyClientManager) AddProxyClient(session *sessionutil.Session) { +func (p *ProxyClientManager) AddProxyClient(session *sessionutil.Session) { _, ok := p.proxyClient.Get(session.ServerID) if ok { return @@ -91,16 +134,16 @@ func (p *proxyClientManager) AddProxyClient(session *sessionutil.Session) { } // GetProxyCount returns number of proxy clients. -func (p *proxyClientManager) GetProxyCount() int { +func (p *ProxyClientManager) GetProxyCount() int { return p.proxyClient.Len() } // mutex.Lock is required before calling this method. -func (p *proxyClientManager) updateProxyNumMetric() { +func (p *ProxyClientManager) updateProxyNumMetric() { metrics.RootCoordProxyCounter.WithLabelValues().Set(float64(p.proxyClient.Len())) } -func (p *proxyClientManager) connect(session *sessionutil.Session) { +func (p *ProxyClientManager) connect(session *sessionutil.Session) { pc, err := p.creator(context.Background(), session.Address, session.ServerID) if err != nil { log.Warn("failed to create proxy client", zap.String("address", session.Address), zap.Int64("serverID", session.ServerID), zap.Error(err)) @@ -116,7 +159,7 @@ func (p *proxyClientManager) connect(session *sessionutil.Session) { p.helper.afterConnect() } -func (p *proxyClientManager) DelProxyClient(s *sessionutil.Session) { +func (p *ProxyClientManager) DelProxyClient(s *sessionutil.Session) { cli, ok := p.proxyClient.GetAndRemove(s.GetServerID()) if ok { cli.Close() @@ -126,12 +169,12 @@ func (p *proxyClientManager) DelProxyClient(s *sessionutil.Session) { log.Info("remove proxy client", zap.String("proxy address", s.Address), zap.Int64("proxy id", s.ServerID)) } -func (p *proxyClientManager) InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest, opts ...expireCacheOpt) error { - c := defaultExpireCacheConfig() +func (p *ProxyClientManager) InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest, opts ...ExpireCacheOpt) error { + c := DefaultExpireCacheConfig() for _, opt := range opts { opt(&c) } - c.apply(request) + c.Apply(request) if p.proxyClient.Len() == 0 { log.Warn("proxy client is empty, InvalidateCollectionMetaCache will not send to any client") @@ -161,7 +204,7 @@ func (p *proxyClientManager) InvalidateCollectionMetaCache(ctx context.Context, } // InvalidateCredentialCache TODO: too many codes similar to InvalidateCollectionMetaCache. -func (p *proxyClientManager) InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) error { +func (p *ProxyClientManager) InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) error { if p.proxyClient.Len() == 0 { log.Warn("proxy client is empty, InvalidateCredentialCache will not send to any client") return nil @@ -187,7 +230,7 @@ func (p *proxyClientManager) InvalidateCredentialCache(ctx context.Context, requ } // UpdateCredentialCache TODO: too many codes similar to InvalidateCollectionMetaCache. -func (p *proxyClientManager) UpdateCredentialCache(ctx context.Context, request *proxypb.UpdateCredCacheRequest) error { +func (p *ProxyClientManager) UpdateCredentialCache(ctx context.Context, request *proxypb.UpdateCredCacheRequest) error { if p.proxyClient.Len() == 0 { log.Warn("proxy client is empty, UpdateCredentialCache will not send to any client") return nil @@ -212,7 +255,7 @@ func (p *proxyClientManager) UpdateCredentialCache(ctx context.Context, request } // RefreshPolicyInfoCache TODO: too many codes similar to InvalidateCollectionMetaCache. -func (p *proxyClientManager) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) error { +func (p *ProxyClientManager) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) error { if p.proxyClient.Len() == 0 { log.Warn("proxy client is empty, RefreshPrivilegeInfoCache will not send to any client") return nil @@ -237,7 +280,7 @@ func (p *proxyClientManager) RefreshPolicyInfoCache(ctx context.Context, req *pr } // GetProxyMetrics sends requests to proxies to get metrics. -func (p *proxyClientManager) GetProxyMetrics(ctx context.Context) ([]*milvuspb.GetMetricsResponse, error) { +func (p *ProxyClientManager) GetProxyMetrics(ctx context.Context) ([]*milvuspb.GetMetricsResponse, error) { if p.proxyClient.Len() == 0 { log.Warn("proxy client is empty, GetMetrics will not send to any client") return nil, nil @@ -276,7 +319,7 @@ func (p *proxyClientManager) GetProxyMetrics(ctx context.Context) ([]*milvuspb.G } // SetRates notifies Proxy to limit rates of requests. -func (p *proxyClientManager) SetRates(ctx context.Context, request *proxypb.SetRatesRequest) error { +func (p *ProxyClientManager) SetRates(ctx context.Context, request *proxypb.SetRatesRequest) error { if p.proxyClient.Len() == 0 { log.Warn("proxy client is empty, SetRates will not send to any client") return nil @@ -299,3 +342,27 @@ func (p *proxyClientManager) SetRates(ctx context.Context, request *proxypb.SetR }) return group.Wait() } + +func (p *ProxyClientManager) GetComponentStates(ctx context.Context) (map[int64]*milvuspb.ComponentStates, error) { + group, ctx := errgroup.WithContext(ctx) + states := make(map[int64]*milvuspb.ComponentStates) + + p.proxyClient.Range(func(key int64, value types.ProxyClient) bool { + k, v := key, value + group.Go(func() error { + sta, err := v.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) + if err != nil { + return err + } + states[k] = sta + return nil + }) + return true + }) + err := group.Wait() + if err != nil { + return nil, err + } + + return states, nil +} diff --git a/internal/util/proxyutil/proxy_client_manager_test.go b/internal/util/proxyutil/proxy_client_manager_test.go new file mode 100644 index 0000000000..98b024e959 --- /dev/null +++ b/internal/util/proxyutil/proxy_client_manager_test.go @@ -0,0 +1,426 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxyutil + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + milvuspb "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proto/proxypb" + "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +type UniqueID = int64 + +var ( + Params = paramtable.Get() + TestProxyID = int64(1) +) + +type proxyMock struct { + types.ProxyClient + collArray []string + collIDs []UniqueID + mutex sync.Mutex + + returnError bool + returnGrpcError bool +} + +func (p *proxyMock) Stop() error { + return nil +} + +func (p *proxyMock) InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { + p.mutex.Lock() + defer p.mutex.Unlock() + if p.returnError { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + }, nil + } + if p.returnGrpcError { + return nil, fmt.Errorf("grpc error") + } + p.collArray = append(p.collArray, request.CollectionName) + p.collIDs = append(p.collIDs, request.CollectionID) + return merr.Success(), nil +} + +func (p *proxyMock) GetCollArray() []string { + p.mutex.Lock() + defer p.mutex.Unlock() + ret := make([]string, 0, len(p.collArray)) + ret = append(ret, p.collArray...) + return ret +} + +func (p *proxyMock) GetCollIDs() []UniqueID { + p.mutex.Lock() + defer p.mutex.Unlock() + ret := p.collIDs + return ret +} + +func (p *proxyMock) InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) { + if p.returnError { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + }, nil + } + if p.returnGrpcError { + return nil, fmt.Errorf("grpc error") + } + return merr.Success(), nil +} + +func (p *proxyMock) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) { + return merr.Success(), nil +} + +func TestProxyClientManager_AddProxyClients(t *testing.T) { + proxyCreator := func(ctx context.Context, addr string, nodeID int64) (types.ProxyClient, error) { + return nil, errors.New("failed") + } + + pcm := NewProxyClientManager(proxyCreator) + + session := &sessionutil.Session{ + SessionRaw: sessionutil.SessionRaw{ + ServerID: 100, + Address: "localhost", + }, + } + + sessions := []*sessionutil.Session{session} + pcm.AddProxyClients(sessions) +} + +func TestProxyClientManager_AddProxyClient(t *testing.T) { + proxyCreator := func(ctx context.Context, addr string, nodeID int64) (types.ProxyClient, error) { + return nil, errors.New("failed") + } + + pcm := NewProxyClientManager(proxyCreator) + + session := &sessionutil.Session{ + SessionRaw: sessionutil.SessionRaw{ + ServerID: 100, + Address: "localhost", + }, + } + + pcm.AddProxyClient(session) +} + +func TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) { + t.Run("empty proxy list", func(t *testing.T) { + ctx := context.Background() + pcm := NewProxyClientManager(DefaultProxyCreator) + err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) + assert.NoError(t, err) + }) + + t.Run("mock rpc error", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + p1.EXPECT().InvalidateCollectionMetaCache(mock.Anything, mock.Anything).Return(merr.Success(), errors.New("error mock InvalidateCollectionMetaCache")) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) + assert.Error(t, err) + }) + + t.Run("mock error code", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + p1.EXPECT().InvalidateCollectionMetaCache(mock.Anything, mock.Anything).Return(merr.Status(errors.New("mock error")), nil) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) + assert.Error(t, err) + }) + + t.Run("mock proxy service down", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + p1.EXPECT().InvalidateCollectionMetaCache(mock.Anything, mock.Anything).Return(nil, merr.ErrNodeNotFound) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) + assert.NoError(t, err) + }) + + t.Run("normal case", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + p1.EXPECT().InvalidateCollectionMetaCache(mock.Anything, mock.Anything).Return(merr.Success(), nil) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{}) + assert.NoError(t, err) + }) +} + +func TestProxyClientManager_InvalidateCredentialCache(t *testing.T) { + t.Run("empty proxy list", func(t *testing.T) { + ctx := context.Background() + pcm := NewProxyClientManager(DefaultProxyCreator) + err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{}) + assert.NoError(t, err) + }) + + t.Run("mock rpc error", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + p1.EXPECT().InvalidateCredentialCache(mock.Anything, mock.Anything).Return(merr.Success(), errors.New("error mock InvalidateCredentialCache")) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{}) + assert.Error(t, err) + }) + + t.Run("mock error code", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + mockErr := errors.New("mock error") + p1.EXPECT().InvalidateCredentialCache(mock.Anything, mock.Anything).Return(merr.Status(mockErr), nil) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{}) + assert.Error(t, err) + }) + + t.Run("normal case", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + p1.EXPECT().InvalidateCredentialCache(mock.Anything, mock.Anything).Return(merr.Success(), nil) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{}) + assert.NoError(t, err) + }) +} + +func TestProxyClientManager_UpdateCredentialCache(t *testing.T) { + TestProxyID := int64(1001) + t.Run("empty proxy list", func(t *testing.T) { + ctx := context.Background() + pcm := NewProxyClientManager(DefaultProxyCreator) + + err := pcm.UpdateCredentialCache(ctx, &proxypb.UpdateCredCacheRequest{}) + assert.NoError(t, err) + }) + + t.Run("mock rpc error", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + p1.EXPECT().UpdateCredentialCache(mock.Anything, mock.Anything).Return(merr.Success(), errors.New("error mock InvalidateCredentialCache")) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + err := pcm.UpdateCredentialCache(ctx, &proxypb.UpdateCredCacheRequest{}) + assert.Error(t, err) + }) + + t.Run("mock error code", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + mockErr := errors.New("mock error") + p1.EXPECT().UpdateCredentialCache(mock.Anything, mock.Anything).Return(merr.Status(mockErr), nil) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + err := pcm.UpdateCredentialCache(ctx, &proxypb.UpdateCredCacheRequest{}) + assert.Error(t, err) + }) + + t.Run("normal case", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + p1.EXPECT().UpdateCredentialCache(mock.Anything, mock.Anything).Return(merr.Success(), nil) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + err := pcm.UpdateCredentialCache(ctx, &proxypb.UpdateCredCacheRequest{}) + assert.NoError(t, err) + }) +} + +func TestProxyClientManager_RefreshPolicyInfoCache(t *testing.T) { + t.Run("empty proxy list", func(t *testing.T) { + ctx := context.Background() + pcm := NewProxyClientManager(DefaultProxyCreator) + + err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{}) + assert.NoError(t, err) + }) + + t.Run("mock rpc error", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + p1.EXPECT().RefreshPolicyInfoCache(mock.Anything, mock.Anything).Return(merr.Success(), errors.New("error mock RefreshPolicyInfoCache")) + + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{}) + assert.Error(t, err) + }) + + t.Run("mock error code", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + p1.EXPECT().RefreshPolicyInfoCache(mock.Anything, mock.Anything).Return(merr.Status(errors.New("mock error")), nil) + + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{}) + assert.Error(t, err) + }) + + t.Run("normal case", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + + p1.EXPECT().RefreshPolicyInfoCache(mock.Anything, mock.Anything).Return(merr.Success(), nil) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{}) + assert.NoError(t, err) + }) +} + +func TestProxyClientManager_GetProxyMetrics(t *testing.T) { + TestProxyID := int64(1001) + t.Run("empty proxy list", func(t *testing.T) { + ctx := context.Background() + pcm := NewProxyClientManager(DefaultProxyCreator) + _, err := pcm.GetProxyMetrics(ctx) + assert.NoError(t, err) + }) + + t.Run("mock rpc error", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + p1.EXPECT().GetProxyMetrics(mock.Anything, mock.Anything).Return(nil, errors.New("error mock InvalidateCredentialCache")) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + _, err := pcm.GetProxyMetrics(ctx) + assert.Error(t, err) + }) + + t.Run("mock error code", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + mockErr := errors.New("mock error") + p1.EXPECT().GetProxyMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{Status: merr.Status(mockErr)}, nil) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + _, err := pcm.GetProxyMetrics(ctx) + assert.Error(t, err) + }) + + t.Run("normal case", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + p1.EXPECT().GetProxyMetrics(mock.Anything, mock.Anything).Return(&milvuspb.GetMetricsResponse{Status: merr.Success()}, nil) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + _, err := pcm.GetProxyMetrics(ctx) + assert.NoError(t, err) + }) +} + +func TestProxyClientManager_SetRates(t *testing.T) { + TestProxyID := int64(1001) + t.Run("empty proxy list", func(t *testing.T) { + ctx := context.Background() + pcm := NewProxyClientManager(DefaultProxyCreator) + err := pcm.SetRates(ctx, &proxypb.SetRatesRequest{}) + assert.NoError(t, err) + }) + + t.Run("mock rpc error", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + p1.EXPECT().SetRates(mock.Anything, mock.Anything).Return(nil, errors.New("error mock InvalidateCredentialCache")) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + err := pcm.SetRates(ctx, &proxypb.SetRatesRequest{}) + assert.Error(t, err) + }) + + t.Run("mock error code", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + mockErr := errors.New("mock error") + p1.EXPECT().SetRates(mock.Anything, mock.Anything).Return(merr.Status(mockErr), nil) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + err := pcm.SetRates(ctx, &proxypb.SetRatesRequest{}) + assert.Error(t, err) + }) + + t.Run("normal case", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + p1.EXPECT().SetRates(mock.Anything, mock.Anything).Return(merr.Success(), nil) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + err := pcm.SetRates(ctx, &proxypb.SetRatesRequest{}) + assert.NoError(t, err) + }) +} + +func TestProxyClientManager_GetComponentStates(t *testing.T) { + TestProxyID := int64(1001) + t.Run("empty proxy list", func(t *testing.T) { + ctx := context.Background() + pcm := NewProxyClientManager(DefaultProxyCreator) + _, err := pcm.GetComponentStates(ctx) + assert.NoError(t, err) + }) + + t.Run("mock rpc error", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + p1.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, errors.New("error mock InvalidateCredentialCache")) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + _, err := pcm.GetComponentStates(ctx) + assert.Error(t, err) + }) + + t.Run("normal case", func(t *testing.T) { + ctx := context.Background() + p1 := mocks.NewMockProxyClient(t) + p1.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(&milvuspb.ComponentStates{Status: merr.Success()}, nil) + pcm := NewProxyClientManager(DefaultProxyCreator) + pcm.proxyClient.Insert(TestProxyID, p1) + _, err := pcm.GetComponentStates(ctx) + assert.NoError(t, err) + }) +} diff --git a/internal/rootcoord/proxy_manager.go b/internal/util/proxyutil/proxy_watcher.go similarity index 73% rename from internal/rootcoord/proxy_manager.go rename to internal/util/proxyutil/proxy_watcher.go index 3724d7029d..0cd81bada6 100644 --- a/internal/rootcoord/proxy_manager.go +++ b/internal/util/proxyutil/proxy_watcher.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package rootcoord +package proxyutil import ( "context" @@ -32,56 +32,62 @@ import ( "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) -// proxyManager manages proxy connected to the rootcoord -type proxyManager struct { - ctx context.Context - cancel context.CancelFunc +type ProxyWatcherInterface interface { + AddSessionFunc(fns ...func(*sessionutil.Session)) + DelSessionFunc(fns ...func(*sessionutil.Session)) + + WatchProxy(ctx context.Context) error + Stop() +} + +// ProxyWatcher manages proxy clients +type ProxyWatcher struct { wg errgroup.Group lock sync.Mutex etcdCli *clientv3.Client initSessionsFunc []func([]*sessionutil.Session) addSessionsFunc []func(*sessionutil.Session) delSessionsFunc []func(*sessionutil.Session) + + closeCh chan struct{} } -// newProxyManager helper function to create a proxyManager -// etcdEndpoints is the address list of etcd +// NewProxyWatcher helper function to create a proxyWatcher // fns are the custom getSessions function list -func newProxyManager(ctx context.Context, client *clientv3.Client, fns ...func([]*sessionutil.Session)) *proxyManager { - ctx, cancel := context.WithCancel(ctx) - p := &proxyManager{ - ctx: ctx, - cancel: cancel, +func NewProxyWatcher(client *clientv3.Client, fns ...func([]*sessionutil.Session)) *ProxyWatcher { + p := &ProxyWatcher{ lock: sync.Mutex{}, etcdCli: client, + closeCh: make(chan struct{}), } p.initSessionsFunc = append(p.initSessionsFunc, fns...) return p } // AddSessionFunc adds functions to addSessions function list -func (p *proxyManager) AddSessionFunc(fns ...func(*sessionutil.Session)) { +func (p *ProxyWatcher) AddSessionFunc(fns ...func(*sessionutil.Session)) { p.lock.Lock() defer p.lock.Unlock() p.addSessionsFunc = append(p.addSessionsFunc, fns...) } // DelSessionFunc add functions to delSessions function list -func (p *proxyManager) DelSessionFunc(fns ...func(*sessionutil.Session)) { +func (p *ProxyWatcher) DelSessionFunc(fns ...func(*sessionutil.Session)) { p.lock.Lock() defer p.lock.Unlock() p.delSessionsFunc = append(p.delSessionsFunc, fns...) } // WatchProxy starts a goroutine to watch proxy session changes on etcd -func (p *proxyManager) WatchProxy() error { - ctx, cancel := context.WithTimeout(p.ctx, Params.ServiceParam.EtcdCfg.RequestTimeout.GetAsDuration(time.Millisecond)) +func (p *ProxyWatcher) WatchProxy(ctx context.Context) error { + childCtx, cancel := context.WithTimeout(ctx, paramtable.Get().ServiceParam.EtcdCfg.RequestTimeout.GetAsDuration(time.Millisecond)) defer cancel() - sessions, rev, err := p.getSessionsOnEtcd(ctx) + sessions, rev, err := p.getSessionsOnEtcd(childCtx) if err != nil { return err } @@ -92,8 +98,8 @@ func (p *proxyManager) WatchProxy() error { } eventCh := p.etcdCli.Watch( - p.ctx, - path.Join(Params.EtcdCfg.MetaRootPath.GetValue(), sessionutil.DefaultServiceRoot, typeutil.ProxyRole), + ctx, + path.Join(paramtable.Get().EtcdCfg.MetaRootPath.GetValue(), sessionutil.DefaultServiceRoot, typeutil.ProxyRole), clientv3.WithPrefix(), clientv3.WithCreatedNotify(), clientv3.WithPrevKV(), @@ -101,20 +107,24 @@ func (p *proxyManager) WatchProxy() error { ) p.wg.Go(func() error { - p.startWatchEtcd(p.ctx, eventCh) + p.startWatchEtcd(ctx, eventCh) return nil }) return nil } -func (p *proxyManager) startWatchEtcd(ctx context.Context, eventCh clientv3.WatchChan) { +func (p *ProxyWatcher) startWatchEtcd(ctx context.Context, eventCh clientv3.WatchChan) { log.Info("start to watch etcd") for { select { case <-ctx.Done(): log.Warn("stop watching etcd loop") return - // TODO @xiaocai2333: watch proxy by session WatchService. + + case <-p.closeCh: + log.Warn("stop watching etcd loop") + return + case event, ok := <-eventCh: if !ok { log.Warn("stop watching etcd loop due to closed etcd event channel") @@ -122,7 +132,7 @@ func (p *proxyManager) startWatchEtcd(ctx context.Context, eventCh clientv3.Watc } if err := event.Err(); err != nil { if err == v3rpc.ErrCompacted { - err2 := p.WatchProxy() + err2 := p.WatchProxy(ctx) if err2 != nil { log.Error("re watch proxy fails when etcd has a compaction error", zap.Error(err), zap.Error(err2)) @@ -149,7 +159,7 @@ func (p *proxyManager) startWatchEtcd(ctx context.Context, eventCh clientv3.Watc } } -func (p *proxyManager) handlePutEvent(e *clientv3.Event) error { +func (p *ProxyWatcher) handlePutEvent(e *clientv3.Event) error { session, err := p.parseSession(e.Kv.Value) if err != nil { return err @@ -161,7 +171,7 @@ func (p *proxyManager) handlePutEvent(e *clientv3.Event) error { return nil } -func (p *proxyManager) handleDeleteEvent(e *clientv3.Event) error { +func (p *ProxyWatcher) handleDeleteEvent(e *clientv3.Event) error { session, err := p.parseSession(e.PrevKv.Value) if err != nil { return err @@ -173,7 +183,7 @@ func (p *proxyManager) handleDeleteEvent(e *clientv3.Event) error { return nil } -func (p *proxyManager) parseSession(value []byte) (*sessionutil.Session, error) { +func (p *ProxyWatcher) parseSession(value []byte) (*sessionutil.Session, error) { session := new(sessionutil.Session) err := json.Unmarshal(value, session) if err != nil { @@ -182,10 +192,10 @@ func (p *proxyManager) parseSession(value []byte) (*sessionutil.Session, error) return session, nil } -func (p *proxyManager) getSessionsOnEtcd(ctx context.Context) ([]*sessionutil.Session, int64, error) { +func (p *ProxyWatcher) getSessionsOnEtcd(ctx context.Context) ([]*sessionutil.Session, int64, error) { resp, err := p.etcdCli.Get( ctx, - path.Join(Params.EtcdCfg.MetaRootPath.GetValue(), sessionutil.DefaultServiceRoot, typeutil.ProxyRole), + path.Join(paramtable.Get().EtcdCfg.MetaRootPath.GetValue(), sessionutil.DefaultServiceRoot, typeutil.ProxyRole), clientv3.WithPrefix(), clientv3.WithSort(clientv3.SortByKey, clientv3.SortAscend), ) @@ -206,8 +216,8 @@ func (p *proxyManager) getSessionsOnEtcd(ctx context.Context) ([]*sessionutil.Se return sessions, resp.Header.Revision, nil } -// Stop stops the proxyManager -func (p *proxyManager) Stop() { - p.cancel() +// Stop stops the ProxyManager +func (p *ProxyWatcher) Stop() { + close(p.closeCh) p.wg.Wait() } diff --git a/internal/rootcoord/proxy_manager_test.go b/internal/util/proxyutil/proxy_watcher_test.go similarity index 77% rename from internal/rootcoord/proxy_manager_test.go rename to internal/util/proxyutil/proxy_watcher_test.go index c60310d414..f1f4a684dd 100644 --- a/internal/rootcoord/proxy_manager_test.go +++ b/internal/util/proxyutil/proxy_watcher_test.go @@ -14,7 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package rootcoord +package proxyutil import ( "context" @@ -37,19 +37,19 @@ func TestProxyManager(t *testing.T) { paramtable.Init() etcdCli, err := etcd.GetEtcdClient( - Params.EtcdCfg.UseEmbedEtcd.GetAsBool(), - Params.EtcdCfg.EtcdUseSSL.GetAsBool(), - Params.EtcdCfg.Endpoints.GetAsStrings(), - Params.EtcdCfg.EtcdTLSCert.GetValue(), - Params.EtcdCfg.EtcdTLSKey.GetValue(), - Params.EtcdCfg.EtcdTLSCACert.GetValue(), - Params.EtcdCfg.EtcdTLSMinVersion.GetValue()) + paramtable.Get().EtcdCfg.UseEmbedEtcd.GetAsBool(), + paramtable.Get().EtcdCfg.EtcdUseSSL.GetAsBool(), + paramtable.Get().EtcdCfg.Endpoints.GetAsStrings(), + paramtable.Get().EtcdCfg.EtcdTLSCert.GetValue(), + paramtable.Get().EtcdCfg.EtcdTLSKey.GetValue(), + paramtable.Get().EtcdCfg.EtcdTLSCACert.GetValue(), + paramtable.Get().EtcdCfg.EtcdTLSMinVersion.GetValue()) assert.NoError(t, err) defer etcdCli.Close() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - sessKey := path.Join(Params.EtcdCfg.MetaRootPath.GetValue(), sessionutil.DefaultServiceRoot) + sessKey := path.Join(paramtable.Get().EtcdCfg.MetaRootPath.GetValue(), sessionutil.DefaultServiceRoot) etcdCli.Delete(ctx, sessKey, clientv3.WithPrefix()) defer etcdCli.Delete(ctx, sessKey, clientv3.WithPrefix()) s1 := sessionutil.Session{ @@ -76,7 +76,7 @@ func TestProxyManager(t *testing.T) { assert.Equal(t, int64(99), sess[1].ServerID) t.Log("get sessions", sess[0], sess[1]) } - pm := newProxyManager(ctx, etcdCli, f1) + pm := NewProxyWatcher(etcdCli, f1) assert.NoError(t, err) fa := func(sess *sessionutil.Session) { assert.Equal(t, int64(101), sess.ServerID) @@ -89,7 +89,7 @@ func TestProxyManager(t *testing.T) { pm.AddSessionFunc(fa) pm.DelSessionFunc(fd) - err = pm.WatchProxy() + err = pm.WatchProxy(ctx) assert.NoError(t, err) t.Log("======== start watch proxy ==========") @@ -113,27 +113,27 @@ func TestProxyManager_ErrCompacted(t *testing.T) { paramtable.Init() etcdCli, err := etcd.GetEtcdClient( - Params.EtcdCfg.UseEmbedEtcd.GetAsBool(), - Params.EtcdCfg.EtcdUseSSL.GetAsBool(), - Params.EtcdCfg.Endpoints.GetAsStrings(), - Params.EtcdCfg.EtcdTLSCert.GetValue(), - Params.EtcdCfg.EtcdTLSKey.GetValue(), - Params.EtcdCfg.EtcdTLSCACert.GetValue(), - Params.EtcdCfg.EtcdTLSMinVersion.GetValue()) + paramtable.Get().EtcdCfg.UseEmbedEtcd.GetAsBool(), + paramtable.Get().EtcdCfg.EtcdUseSSL.GetAsBool(), + paramtable.Get().EtcdCfg.Endpoints.GetAsStrings(), + paramtable.Get().EtcdCfg.EtcdTLSCert.GetValue(), + paramtable.Get().EtcdCfg.EtcdTLSKey.GetValue(), + paramtable.Get().EtcdCfg.EtcdTLSCACert.GetValue(), + paramtable.Get().EtcdCfg.EtcdTLSMinVersion.GetValue()) assert.NoError(t, err) defer etcdCli.Close() ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second) defer cancel() - sessKey := path.Join(Params.EtcdCfg.MetaRootPath.GetValue(), sessionutil.DefaultServiceRoot) + sessKey := path.Join(paramtable.Get().EtcdCfg.MetaRootPath.GetValue(), sessionutil.DefaultServiceRoot) f1 := func(sess []*sessionutil.Session) { t.Log("get sessions num", len(sess)) } - pm := newProxyManager(ctx, etcdCli, f1) + pm := NewProxyWatcher(etcdCli, f1) eventCh := pm.etcdCli.Watch( - pm.ctx, - path.Join(Params.EtcdCfg.MetaRootPath.GetValue(), sessionutil.DefaultServiceRoot, typeutil.ProxyRole), + ctx, + path.Join(paramtable.Get().EtcdCfg.MetaRootPath.GetValue(), sessionutil.DefaultServiceRoot, typeutil.ProxyRole), clientv3.WithPrefix(), clientv3.WithCreatedNotify(), clientv3.WithPrevKV(), @@ -152,7 +152,7 @@ func TestProxyManager_ErrCompacted(t *testing.T) { etcdCli.Compact(ctx, 10) assert.Panics(t, func() { - pm.startWatchEtcd(pm.ctx, eventCh) + pm.startWatchEtcd(ctx, eventCh) }) for i := 1; i < 10; i++ { diff --git a/scripts/run_go_unittest.sh b/scripts/run_go_unittest.sh index aea157c0ba..40722787c8 100755 --- a/scripts/run_go_unittest.sh +++ b/scripts/run_go_unittest.sh @@ -109,6 +109,7 @@ go test -race -cover -tags dynamic "${PKG_DIR}/util/retry/..." -failfast -count= go test -race -cover -tags dynamic "${MILVUS_DIR}/util/sessionutil/..." -failfast -count=1 -ldflags="-r ${RPATH}" go test -race -cover -tags dynamic "${MILVUS_DIR}/util/typeutil/..." -failfast -count=1 -ldflags="-r ${RPATH}" go test -race -cover -tags dynamic "${MILVUS_DIR}/util/importutil/..." -failfast -count=1 -ldflags="-r ${RPATH}" +go test -race -cover -tags dynamic "${MILVUS_DIR}/util/proxyutil/..." -failfast -count=1 -ldflags="-r ${RPATH}" } function test_pkg()