fix: data race in ProxyClientManager (#29206)

This PR makes ProxyClientManager thread-safe by replacing its mutex-guarded map with a typeutil.ConcurrentMap.
fix #29205

Signed-off-by: yah01 <yang.cen@zilliz.com>
This commit is contained in:
yah01 2023-12-14 18:22:39 +08:00 committed by GitHub
parent bd640754ac
commit b8674811cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 76 additions and 110 deletions

View File

@ -386,7 +386,7 @@ func newTestCore(opts ...Opt) *Core {
func withValidProxyManager() Opt {
return func(c *Core) {
c.proxyClientManager = &proxyClientManager{
proxyClient: make(map[UniqueID]types.ProxyClient),
proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient](),
}
p := newMockProxy()
p.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) {
@ -398,14 +398,14 @@ func withValidProxyManager() Opt {
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil
}
c.proxyClientManager.proxyClient[TestProxyID] = p
c.proxyClientManager.proxyClient.Insert(TestProxyID, p)
}
}
func withInvalidProxyManager() Opt {
return func(c *Core) {
c.proxyClientManager = &proxyClientManager{
proxyClient: make(map[UniqueID]types.ProxyClient),
proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient](),
}
p := newMockProxy()
p.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) {
@ -417,7 +417,7 @@ func withInvalidProxyManager() Opt {
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
}, nil
}
c.proxyClientManager.proxyClient[TestProxyID] = p
c.proxyClientManager.proxyClient.Insert(TestProxyID, p)
}
}

View File

@ -35,6 +35,7 @@ import (
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type proxyCreator func(ctx context.Context, addr string, nodeID int64) (types.ProxyClient, error)
@ -49,8 +50,7 @@ func DefaultProxyCreator(ctx context.Context, addr string, nodeID int64) (types.
type proxyClientManager struct {
creator proxyCreator
lock sync.RWMutex
proxyClient map[int64]types.ProxyClient
proxyClient *typeutil.ConcurrentMap[int64, types.ProxyClient]
helper proxyClientManagerHelper
}
@ -65,7 +65,7 @@ var defaultClientManagerHelper = proxyClientManagerHelper{
func newProxyClientManager(creator proxyCreator) *proxyClientManager {
return &proxyClientManager{
creator: creator,
proxyClient: make(map[int64]types.ProxyClient),
proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient](),
helper: defaultClientManagerHelper,
}
}
@ -76,16 +76,12 @@ func (p *proxyClientManager) AddProxyClients(sessions []*sessionutil.Session) {
}
}
func (p *proxyClientManager) GetProxyClients() map[int64]types.ProxyClient {
p.lock.RLock()
defer p.lock.RUnlock()
func (p *proxyClientManager) GetProxyClients() *typeutil.ConcurrentMap[int64, types.ProxyClient] {
return p.proxyClient
}
func (p *proxyClientManager) AddProxyClient(session *sessionutil.Session) {
p.lock.RLock()
_, ok := p.proxyClient[session.ServerID]
p.lock.RUnlock()
_, ok := p.proxyClient.Get(session.ServerID)
if ok {
return
}
@ -96,15 +92,12 @@ func (p *proxyClientManager) AddProxyClient(session *sessionutil.Session) {
// GetProxyCount returns number of proxy clients.
func (p *proxyClientManager) GetProxyCount() int {
p.lock.Lock()
defer p.lock.Unlock()
return len(p.proxyClient)
return p.proxyClient.Len()
}
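// mutex.Lock is required before calling this method.
// NOTE(review): this precondition predates the switch to typeutil.ConcurrentMap; confirm whether it still applies.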
func (p *proxyClientManager) updateProxyNumMetric() {
metrics.RootCoordProxyCounter.WithLabelValues().Set(float64(len(p.proxyClient)))
metrics.RootCoordProxyCounter.WithLabelValues().Set(float64(p.proxyClient.Len()))
}
func (p *proxyClientManager) connect(session *sessionutil.Session) {
@ -114,51 +107,40 @@ func (p *proxyClientManager) connect(session *sessionutil.Session) {
return
}
p.lock.Lock()
defer p.lock.Unlock()
_, ok := p.proxyClient[session.ServerID]
_, ok := p.proxyClient.GetOrInsert(session.GetServerID(), pc)
if ok {
pc.Close()
return
}
p.proxyClient[session.ServerID] = pc
log.Info("succeed to create proxy client", zap.String("address", session.Address), zap.Int64("serverID", session.ServerID))
p.helper.afterConnect()
}
func (p *proxyClientManager) DelProxyClient(s *sessionutil.Session) {
p.lock.Lock()
defer p.lock.Unlock()
cli, ok := p.proxyClient[s.ServerID]
cli, ok := p.proxyClient.GetAndRemove(s.GetServerID())
if ok {
cli.Close()
}
delete(p.proxyClient, s.ServerID)
p.updateProxyNumMetric()
log.Info("remove proxy client", zap.String("proxy address", s.Address), zap.Int64("proxy id", s.ServerID))
}
func (p *proxyClientManager) InvalidateCollectionMetaCache(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest, opts ...expireCacheOpt) error {
p.lock.Lock()
defer p.lock.Unlock()
c := defaultExpireCacheConfig()
for _, opt := range opts {
opt(&c)
}
c.apply(request)
if len(p.proxyClient) == 0 {
if p.proxyClient.Len() == 0 {
log.Warn("proxy client is empty, InvalidateCollectionMetaCache will not send to any client")
return nil
}
group := &errgroup.Group{}
for k, v := range p.proxyClient {
k, v := k, v
p.proxyClient.Range(func(key int64, value types.ProxyClient) bool {
k, v := key, value
group.Go(func() error {
sta, err := v.InvalidateCollectionMetaCache(ctx, request)
if err != nil {
@ -173,23 +155,21 @@ func (p *proxyClientManager) InvalidateCollectionMetaCache(ctx context.Context,
}
return nil
})
}
return true
})
return group.Wait()
}
// InvalidateCredentialCache TODO: too many codes similar to InvalidateCollectionMetaCache.
func (p *proxyClientManager) InvalidateCredentialCache(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) error {
p.lock.Lock()
defer p.lock.Unlock()
if len(p.proxyClient) == 0 {
if p.proxyClient.Len() == 0 {
log.Warn("proxy client is empty, InvalidateCredentialCache will not send to any client")
return nil
}
group := &errgroup.Group{}
for k, v := range p.proxyClient {
k, v := k, v
p.proxyClient.Range(func(key int64, value types.ProxyClient) bool {
k, v := key, value
group.Go(func() error {
sta, err := v.InvalidateCredentialCache(ctx, request)
if err != nil {
@ -200,23 +180,22 @@ func (p *proxyClientManager) InvalidateCredentialCache(ctx context.Context, requ
}
return nil
})
}
return true
})
return group.Wait()
}
// UpdateCredentialCache TODO: too many codes similar to InvalidateCollectionMetaCache.
func (p *proxyClientManager) UpdateCredentialCache(ctx context.Context, request *proxypb.UpdateCredCacheRequest) error {
p.lock.Lock()
defer p.lock.Unlock()
if len(p.proxyClient) == 0 {
if p.proxyClient.Len() == 0 {
log.Warn("proxy client is empty, UpdateCredentialCache will not send to any client")
return nil
}
group := &errgroup.Group{}
for k, v := range p.proxyClient {
k, v := k, v
p.proxyClient.Range(func(key int64, value types.ProxyClient) bool {
k, v := key, value
group.Go(func() error {
sta, err := v.UpdateCredentialCache(ctx, request)
if err != nil {
@ -227,23 +206,21 @@ func (p *proxyClientManager) UpdateCredentialCache(ctx context.Context, request
}
return nil
})
}
return true
})
return group.Wait()
}
// RefreshPolicyInfoCache TODO: too many codes similar to InvalidateCollectionMetaCache.
func (p *proxyClientManager) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) error {
p.lock.Lock()
defer p.lock.Unlock()
if len(p.proxyClient) == 0 {
if p.proxyClient.Len() == 0 {
log.Warn("proxy client is empty, RefreshPrivilegeInfoCache will not send to any client")
return nil
}
group := &errgroup.Group{}
for k, v := range p.proxyClient {
k, v := k, v
p.proxyClient.Range(func(key int64, value types.ProxyClient) bool {
k, v := key, value
group.Go(func() error {
status, err := v.RefreshPolicyInfoCache(ctx, req)
if err != nil {
@ -254,16 +231,14 @@ func (p *proxyClientManager) RefreshPolicyInfoCache(ctx context.Context, req *pr
}
return nil
})
}
return true
})
return group.Wait()
}
// GetProxyMetrics sends requests to proxies to get metrics.
func (p *proxyClientManager) GetProxyMetrics(ctx context.Context) ([]*milvuspb.GetMetricsResponse, error) {
p.lock.Lock()
defer p.lock.Unlock()
if len(p.proxyClient) == 0 {
if p.proxyClient.Len() == 0 {
log.Warn("proxy client is empty, GetMetrics will not send to any client")
return nil, nil
}
@ -276,8 +251,8 @@ func (p *proxyClientManager) GetProxyMetrics(ctx context.Context) ([]*milvuspb.G
group := &errgroup.Group{}
var metricRspsMu sync.Mutex
metricRsps := make([]*milvuspb.GetMetricsResponse, 0)
for k, v := range p.proxyClient {
k, v := k, v
p.proxyClient.Range(func(key int64, value types.ProxyClient) bool {
k, v := key, value
group.Go(func() error {
rsp, err := v.GetProxyMetrics(ctx, req)
if err != nil {
@ -291,7 +266,8 @@ func (p *proxyClientManager) GetProxyMetrics(ctx context.Context) ([]*milvuspb.G
metricRspsMu.Unlock()
return nil
})
}
return true
})
err = group.Wait()
if err != nil {
return nil, err
@ -301,17 +277,14 @@ func (p *proxyClientManager) GetProxyMetrics(ctx context.Context) ([]*milvuspb.G
// SetRates notifies Proxy to limit rates of requests.
func (p *proxyClientManager) SetRates(ctx context.Context, request *proxypb.SetRatesRequest) error {
p.lock.Lock()
defer p.lock.Unlock()
if len(p.proxyClient) == 0 {
if p.proxyClient.Len() == 0 {
log.Warn("proxy client is empty, SetRates will not send to any client")
return nil
}
group := &errgroup.Group{}
for k, v := range p.proxyClient {
k, v := k, v
p.proxyClient.Range(func(key int64, value types.ProxyClient) bool {
k, v := key, value
group.Go(func() error {
sta, err := v.SetRates(ctx, request)
if err != nil {
@ -322,6 +295,7 @@ func (p *proxyClientManager) SetRates(ctx context.Context, request *proxypb.SetR
}
return nil
})
}
return true
})
return group.Wait()
}

View File

@ -32,6 +32,7 @@ import (
"github.com/milvus-io/milvus/pkg/util/etcd"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/typeutil"
)
type proxyMock struct {
@ -164,7 +165,7 @@ func TestProxyClientManager_AddProxyClient(t *testing.T) {
func TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) {
t.Run("empty proxy list", func(t *testing.T) {
ctx := context.Background()
pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{}}
pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()}
err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{})
assert.NoError(t, err)
})
@ -175,9 +176,8 @@ func TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) {
p1.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) {
return merr.Success(), errors.New("error mock InvalidateCollectionMetaCache")
}
pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{
TestProxyID: p1,
}}
pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()}
pcm.proxyClient.Insert(TestProxyID, p1)
err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{})
assert.Error(t, err)
})
@ -189,9 +189,8 @@ func TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) {
p1.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) {
return merr.Status(mockErr), nil
}
pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{
TestProxyID: p1,
}}
pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()}
pcm.proxyClient.Insert(TestProxyID, p1)
err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{})
assert.Error(t, err)
})
@ -202,9 +201,8 @@ func TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) {
p1.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) {
return nil, merr.ErrNodeNotFound
}
pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{
TestProxyID: p1,
}}
pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()}
pcm.proxyClient.Insert(TestProxyID, p1)
err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{})
assert.NoError(t, err)
@ -216,9 +214,8 @@ func TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) {
p1.InvalidateCollectionMetaCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) {
return merr.Success(), nil
}
pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{
TestProxyID: p1,
}}
pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()}
pcm.proxyClient.Insert(TestProxyID, p1)
err := pcm.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{})
assert.NoError(t, err)
})
@ -227,7 +224,7 @@ func TestProxyClientManager_InvalidateCollectionMetaCache(t *testing.T) {
func TestProxyClientManager_InvalidateCredentialCache(t *testing.T) {
t.Run("empty proxy list", func(t *testing.T) {
ctx := context.Background()
pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{}}
pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()}
err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{})
assert.NoError(t, err)
})
@ -238,9 +235,8 @@ func TestProxyClientManager_InvalidateCredentialCache(t *testing.T) {
p1.InvalidateCredentialCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) {
return merr.Success(), errors.New("error mock InvalidateCredentialCache")
}
pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{
TestProxyID: p1,
}}
pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()}
pcm.proxyClient.Insert(TestProxyID, p1)
err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{})
assert.Error(t, err)
})
@ -252,9 +248,8 @@ func TestProxyClientManager_InvalidateCredentialCache(t *testing.T) {
p1.InvalidateCredentialCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) {
return merr.Status(mockErr), nil
}
pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{
TestProxyID: p1,
}}
pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()}
pcm.proxyClient.Insert(TestProxyID, p1)
err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{})
assert.Error(t, err)
})
@ -265,9 +260,8 @@ func TestProxyClientManager_InvalidateCredentialCache(t *testing.T) {
p1.InvalidateCredentialCacheFunc = func(ctx context.Context, request *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) {
return merr.Success(), nil
}
pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{
TestProxyID: p1,
}}
pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()}
pcm.proxyClient.Insert(TestProxyID, p1)
err := pcm.InvalidateCredentialCache(ctx, &proxypb.InvalidateCredCacheRequest{})
assert.NoError(t, err)
})
@ -276,7 +270,7 @@ func TestProxyClientManager_InvalidateCredentialCache(t *testing.T) {
func TestProxyClientManager_RefreshPolicyInfoCache(t *testing.T) {
t.Run("empty proxy list", func(t *testing.T) {
ctx := context.Background()
pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{}}
pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()}
err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{})
assert.NoError(t, err)
})
@ -287,9 +281,8 @@ func TestProxyClientManager_RefreshPolicyInfoCache(t *testing.T) {
p1.RefreshPolicyInfoCacheFunc = func(ctx context.Context, request *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) {
return merr.Success(), errors.New("error mock RefreshPolicyInfoCache")
}
pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{
TestProxyID: p1,
}}
pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()}
pcm.proxyClient.Insert(TestProxyID, p1)
err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{})
assert.Error(t, err)
})
@ -301,9 +294,8 @@ func TestProxyClientManager_RefreshPolicyInfoCache(t *testing.T) {
p1.RefreshPolicyInfoCacheFunc = func(ctx context.Context, request *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) {
return merr.Status(mockErr), nil
}
pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{
TestProxyID: p1,
}}
pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()}
pcm.proxyClient.Insert(TestProxyID, p1)
err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{})
assert.Error(t, err)
})
@ -314,9 +306,8 @@ func TestProxyClientManager_RefreshPolicyInfoCache(t *testing.T) {
p1.RefreshPolicyInfoCacheFunc = func(ctx context.Context, request *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) {
return merr.Success(), nil
}
pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{
TestProxyID: p1,
}}
pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()}
pcm.proxyClient.Insert(TestProxyID, p1)
err := pcm.RefreshPolicyInfoCache(ctx, &proxypb.RefreshPolicyInfoCacheRequest{})
assert.NoError(t, err)
})

View File

@ -533,9 +533,8 @@ func TestQuotaCenter(t *testing.T) {
qc := mocks.NewMockQueryCoordClient(t)
p1 := mocks.NewMockProxyClient(t)
p1.EXPECT().SetRates(mock.Anything, mock.Anything).Return(nil, nil)
pcm := &proxyClientManager{proxyClient: map[int64]types.ProxyClient{
TestProxyID: p1,
}}
pcm := &proxyClientManager{proxyClient: typeutil.NewConcurrentMap[int64, types.ProxyClient]()}
pcm.proxyClient.Insert(TestProxyID, p1)
meta := mockrootcoord.NewIMetaTable(t)
meta.EXPECT().GetCollectionByID(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, merr.ErrCollectionNotFound).Maybe()
quotaCenter := NewQuotaCenter(pcm, qc, dc, core.tsoAllocator, meta)

View File

@ -2782,9 +2782,10 @@ func (c *Core) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest)
group, ctx := errgroup.WithContext(ctx)
errReasons := make([]string, 0, c.proxyClientManager.GetProxyCount())
for nodeID, proxyClient := range c.proxyClientManager.GetProxyClients() {
nodeID := nodeID
proxyClient := proxyClient
proxyClients := c.proxyClientManager.GetProxyClients()
proxyClients.Range(func(key int64, value types.ProxyClient) bool {
nodeID := key
proxyClient := value
group.Go(func() error {
sta, err := proxyClient.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{})
if err != nil {
@ -2799,7 +2800,8 @@ func (c *Core) CheckHealth(ctx context.Context, in *milvuspb.CheckHealthRequest)
}
return nil
})
}
return true
})
err := group.Wait()
if err != nil || len(errReasons) != 0 {