fix: wrong context used by session of grpc client (#46183)

issue: #46182

Signed-off-by: chyezh <chyezh@outlook.com>
This commit is contained in:
Zhen Ye 2025-12-08 21:47:12 +08:00 committed by GitHub
parent a042a6e1e8
commit 459425ac84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 51 additions and 91 deletions

View File

@ -46,7 +46,7 @@ func NewRunner(ctx context.Context, cfg *configs.Config) *Runner {
func (r *Runner) watchByPrefix(prefix string) {
defer r.wg.Done()
_, revision, err := r.session.GetSessions(prefix)
_, revision, err := r.session.GetSessions(r.ctx, prefix)
fn := func() { r.Stop() }
console.AbnormalExitIf(err, r.backupFinished.Load(), console.AddCallbacks(fn))
watcher := r.session.WatchServices(prefix, revision, nil)
@ -128,7 +128,7 @@ func (r *Runner) CheckCompatible() bool {
}
func (r *Runner) checkSessionsWithPrefix(prefix string) error {
sessions, _, err := r.session.GetSessions(prefix)
sessions, _, err := r.session.GetSessions(r.ctx, prefix)
if err != nil {
return err
}
@ -139,7 +139,7 @@ func (r *Runner) checkSessionsWithPrefix(prefix string) error {
}
func (r *Runner) checkMySelf() error {
sessions, _, err := r.session.GetSessions(Role)
sessions, _, err := r.session.GetSessions(r.ctx, Role)
if err != nil {
return err
}

View File

@ -235,7 +235,7 @@ func (s *Server) Register() error {
}
func (s *Server) ServerExist(serverID int64) bool {
sessions, _, err := s.session.GetSessions(typeutil.DataNodeRole)
sessions, _, err := s.session.GetSessions(s.ctx, typeutil.DataNodeRole)
if err != nil {
log.Ctx(s.ctx).Warn("failed to get sessions", zap.Error(err))
return false
@ -547,7 +547,7 @@ func (s *Server) initServiceDiscovery() error {
}
s.indexEngineVersionManager = newIndexEngineVersionManager()
qnSessions, qnRevision, err := s.session.GetSessions(typeutil.QueryNodeRole)
qnSessions, qnRevision, err := s.session.GetSessions(s.ctx, typeutil.QueryNodeRole)
if err != nil {
log.Warn("DataCoord get QueryNode sessions failed", zap.Error(err))
return err

View File

@ -105,7 +105,7 @@ func (cm *ConnectionManager) AddDependency(roleName string) error {
}
cm.dependencies[roleName] = struct{}{}
msess, rev, err := cm.session.GetSessions(roleName)
msess, rev, err := cm.session.GetSessions(context.TODO(), roleName)
if err != nil {
log.Debug("ClientManager GetSessions failed", zap.String("roleName", roleName))
return err

View File

@ -50,10 +50,8 @@ type DataNodeClient struct {
// Client is the grpc client for DataNode
type Client struct {
grpcClient grpcclient.GrpcClient[DataNodeClient]
sess *sessionutil.Session
addr string
serverID int64
ctx context.Context
}
// NewClient creates a client for DataNode.
@ -61,7 +59,7 @@ func NewClient(ctx context.Context, addr string, serverID int64, encryption bool
if addr == "" {
return nil, errors.New("address is empty")
}
sess := sessionutil.NewSession(ctx)
sess := sessionutil.NewSession(context.Background())
if sess == nil {
err := errors.New("new session error, maybe can not connect to etcd")
log.Ctx(ctx).Debug("DataNodeClient New Etcd Session failed", zap.Error(err))
@ -72,9 +70,7 @@ func NewClient(ctx context.Context, addr string, serverID int64, encryption bool
client := &Client{
addr: addr,
grpcClient: grpcclient.NewClientBase[DataNodeClient](config, "milvus.proto.data.DataNode"),
sess: sess,
serverID: serverID,
ctx: ctx,
}
// node shall specify node id
client.grpcClient.SetRole(fmt.Sprintf("%s-%d", typeutil.DataNodeRole, serverID))

View File

@ -68,7 +68,7 @@ type Client struct {
// etcdEndpoints are the address list for etcd end points
// timeout is default setting for each grpc call
func NewClient(ctx context.Context) (types.MixCoordClient, error) {
sess := sessionutil.NewSession(ctx)
sess := sessionutil.NewSession(context.Background())
if sess == nil {
err := errors.New("new session error, maybe can not connect to etcd")
log.Ctx(ctx).Debug("New MixCoord Client failed", zap.Error(err))
@ -110,7 +110,7 @@ func (c *Client) newGrpcClient(cc *grpc.ClientConn) MixCoordClient {
func (c *Client) getMixCoordAddr() (string, error) {
log := log.Ctx(c.ctx)
key := c.grpcClient.GetRole()
msess, _, err := c.sess.GetSessions(key)
msess, _, err := c.sess.GetSessions(c.ctx, key)
if err != nil {
log.Debug("MixCoordClient GetSessions failed", zap.Any("key", key))
return "", err
@ -135,7 +135,7 @@ func (c *Client) getMixCoordAddr() (string, error) {
// compatible with standalone mode upgrade from 2.5, should be removed in 3.0
func (c *Client) getCompatibleMixCoordAddr() (string, error) {
log := log.Ctx(c.ctx)
msess, _, err := c.sess.GetSessions(typeutil.RootCoordRole)
msess, _, err := c.sess.GetSessions(c.ctx, typeutil.RootCoordRole)
if err != nil {
log.Debug("mixCoordClient getSessions failed", zap.Any("key", typeutil.RootCoordRole), zap.Error(err))
return "", errors.New("find no available mixcoord, check mixcoord state")

View File

@ -45,7 +45,6 @@ var Params *paramtable.ComponentParam = paramtable.Get()
type Client struct {
grpcClient grpcclient.GrpcClient[proxypb.ProxyClient]
addr string
sess *sessionutil.Session
}
// NewClient creates a new client instance
@ -53,7 +52,7 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (types.ProxyClien
if addr == "" {
return nil, errors.New("address is empty")
}
sess := sessionutil.NewSession(ctx)
sess := sessionutil.NewSession(context.Background())
if sess == nil {
err := errors.New("new session error, maybe can not connect to etcd")
log.Ctx(ctx).Debug("Proxy client new session failed", zap.Error(err))
@ -63,7 +62,6 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (types.ProxyClien
client := &Client{
addr: addr,
grpcClient: grpcclient.NewClientBase[proxypb.ProxyClient](config, "milvus.proto.proxy.Proxy"),
sess: sess,
}
// node shall specify node id
client.grpcClient.SetRole(fmt.Sprintf("%s-%d", typeutil.ProxyRole, nodeID))

View File

@ -45,9 +45,7 @@ var Params *paramtable.ComponentParam = paramtable.Get()
type Client struct {
grpcClient grpcclient.GrpcClient[querypb.QueryNodeClient]
addr string
sess *sessionutil.Session
nodeID int64
ctx context.Context
}
// NewClient creates a new QueryNode client.
@ -55,7 +53,7 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (types.QueryNodeC
if addr == "" {
return nil, errors.New("addr is empty")
}
sess := sessionutil.NewSession(ctx)
sess := sessionutil.NewSession(context.Background())
if sess == nil {
err := errors.New("new session error, maybe can not connect to etcd")
log.Ctx(ctx).Debug("QueryNodeClient NewClient failed", zap.Error(err))
@ -65,9 +63,7 @@ func NewClient(ctx context.Context, addr string, nodeID int64) (types.QueryNodeC
client := &Client{
addr: addr,
grpcClient: grpcclient.NewClientBase[querypb.QueryNodeClient](config, "milvus.proto.query.QueryNode"),
sess: sess,
nodeID: nodeID,
ctx: ctx,
}
// node shall specify node id
client.grpcClient.SetRole(fmt.Sprintf("%s-%d", typeutil.QueryNodeRole, nodeID))

View File

@ -166,7 +166,7 @@ func (s *Server) SetSession(session sessionutil.SessionInterface) error {
}
func (s *Server) ServerExist(serverID int64) bool {
sessions, _, err := s.session.GetSessions(typeutil.QueryNodeRole)
sessions, _, err := s.session.GetSessions(s.ctx, typeutil.QueryNodeRole)
if err != nil {
log.Ctx(s.ctx).Warn("failed to get sessions", zap.Error(err))
return false
@ -486,7 +486,7 @@ func (s *Server) Start() error {
func (s *Server) startQueryCoord() error {
log.Ctx(s.ctx).Info("start watcher...")
sessions, revision, err := s.session.GetSessions(typeutil.QueryNodeRole)
sessions, revision, err := s.session.GetSessions(s.ctx, typeutil.QueryNodeRole)
if err != nil {
return err
}

View File

@ -327,7 +327,7 @@ func (node *QueryNode) Init() error {
return NewLocalWorker(node), nil
}
sessions, _, err := node.session.GetSessions(typeutil.QueryNodeRole)
sessions, _, err := node.session.GetSessions(node.ctx, typeutil.QueryNodeRole)
if err != nil {
return nil, err
}

View File

@ -379,7 +379,7 @@ func (c *ClientBase[T]) verifySession(ctx context.Context) error {
}
c.lastSessionCheck.Store(time.Now())
if c.sess != nil {
sessions, _, getSessionErr := c.sess.GetSessions(c.GetRole())
sessions, _, getSessionErr := c.sess.GetSessions(ctx, c.GetRole())
if getSessionErr != nil {
// Only log but not handle this error as it is an auxiliary logic
log.Warn("fail to get session", zap.Error(getSessionErr))

View File

@ -107,7 +107,7 @@ func TestClientBase_NodeSessionNotExist(t *testing.T) {
})
base.role = typeutil.QueryNodeRole
mockSession := sessionutil.NewMockSession(t)
mockSession.EXPECT().GetSessions(mock.Anything).Return(nil, 0, nil)
mockSession.EXPECT().GetSessions(mock.Anything, mock.Anything).Return(nil, 0, nil)
base.sess = mockSession
base.grpcClientMtx.Lock()
base.grpcClient = nil
@ -551,7 +551,7 @@ func TestVerifySession(t *testing.T) {
base := ClientBase[*mockClient]{}
mockSession := sessionutil.NewMockSession(t)
expectedErr := errors.New("mocked")
mockSession.EXPECT().GetSessions(mock.Anything).Return(nil, 0, expectedErr)
mockSession.EXPECT().GetSessions(mock.Anything, mock.Anything).Return(nil, 0, expectedErr)
base.sess = mockSession
ctx := context.Background()
@ -562,7 +562,7 @@ func TestVerifySession(t *testing.T) {
base.NodeID = *atomic.NewInt64(1)
base.role = typeutil.RootCoordRole
mockSession2 := sessionutil.NewMockSession(t)
mockSession2.EXPECT().GetSessions(mock.Anything).Return(
mockSession2.EXPECT().GetSessions(mock.Anything, mock.Anything).Return(
map[string]*sessionutil.Session{
typeutil.RootCoordRole: {
SessionRaw: sessionutil.SessionRaw{

View File

@ -3,10 +3,10 @@
package sessionutil
import (
context "context"
semver "github.com/blang/semver/v4"
mock "github.com/stretchr/testify/mock"
time "time"
)
// MockSession is an autogenerated mock type for the SessionInterface type
@ -157,9 +157,9 @@ func (_c *MockSession_GetServerID_Call) RunAndReturn(run func() int64) *MockSess
return _c
}
// GetSessions provides a mock function with given fields: prefix
func (_m *MockSession) GetSessions(prefix string) (map[string]*Session, int64, error) {
ret := _m.Called(prefix)
// GetSessions provides a mock function with given fields: ctx, prefix
func (_m *MockSession) GetSessions(ctx context.Context, prefix string) (map[string]*Session, int64, error) {
ret := _m.Called(ctx, prefix)
if len(ret) == 0 {
panic("no return value specified for GetSessions")
@ -168,25 +168,25 @@ func (_m *MockSession) GetSessions(prefix string) (map[string]*Session, int64, e
var r0 map[string]*Session
var r1 int64
var r2 error
if rf, ok := ret.Get(0).(func(string) (map[string]*Session, int64, error)); ok {
return rf(prefix)
if rf, ok := ret.Get(0).(func(context.Context, string) (map[string]*Session, int64, error)); ok {
return rf(ctx, prefix)
}
if rf, ok := ret.Get(0).(func(string) map[string]*Session); ok {
r0 = rf(prefix)
if rf, ok := ret.Get(0).(func(context.Context, string) map[string]*Session); ok {
r0 = rf(ctx, prefix)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(map[string]*Session)
}
}
if rf, ok := ret.Get(1).(func(string) int64); ok {
r1 = rf(prefix)
if rf, ok := ret.Get(1).(func(context.Context, string) int64); ok {
r1 = rf(ctx, prefix)
} else {
r1 = ret.Get(1).(int64)
}
if rf, ok := ret.Get(2).(func(string) error); ok {
r2 = rf(prefix)
if rf, ok := ret.Get(2).(func(context.Context, string) error); ok {
r2 = rf(ctx, prefix)
} else {
r2 = ret.Error(2)
}
@ -200,14 +200,15 @@ type MockSession_GetSessions_Call struct {
}
// GetSessions is a helper method to define mock.On call
// - ctx context.Context
// - prefix string
func (_e *MockSession_Expecter) GetSessions(prefix interface{}) *MockSession_GetSessions_Call {
return &MockSession_GetSessions_Call{Call: _e.mock.On("GetSessions", prefix)}
func (_e *MockSession_Expecter) GetSessions(ctx interface{}, prefix interface{}) *MockSession_GetSessions_Call {
return &MockSession_GetSessions_Call{Call: _e.mock.On("GetSessions", ctx, prefix)}
}
func (_c *MockSession_GetSessions_Call) Run(run func(prefix string)) *MockSession_GetSessions_Call {
func (_c *MockSession_GetSessions_Call) Run(run func(ctx context.Context, prefix string)) *MockSession_GetSessions_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(string))
run(args[0].(context.Context), args[1].(string))
})
return _c
}
@ -217,7 +218,7 @@ func (_c *MockSession_GetSessions_Call) Return(_a0 map[string]*Session, _a1 int6
return _c
}
func (_c *MockSession_GetSessions_Call) RunAndReturn(run func(string) (map[string]*Session, int64, error)) *MockSession_GetSessions_Call {
func (_c *MockSession_GetSessions_Call) RunAndReturn(run func(context.Context, string) (map[string]*Session, int64, error)) *MockSession_GetSessions_Call {
_c.Call.Return(run)
return _c
}
@ -594,39 +595,6 @@ func (_c *MockSession_Registered_Call) RunAndReturn(run func() bool) *MockSessio
return _c
}
// Revoke provides a mock function with given fields: timeout
func (_m *MockSession) Revoke(timeout time.Duration) {
_m.Called(timeout)
}
// MockSession_Revoke_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Revoke'
type MockSession_Revoke_Call struct {
*mock.Call
}
// Revoke is a helper method to define mock.On call
// - timeout time.Duration
func (_e *MockSession_Expecter) Revoke(timeout interface{}) *MockSession_Revoke_Call {
return &MockSession_Revoke_Call{Call: _e.mock.On("Revoke", timeout)}
}
func (_c *MockSession_Revoke_Call) Run(run func(timeout time.Duration)) *MockSession_Revoke_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(time.Duration))
})
return _c
}
func (_c *MockSession_Revoke_Call) Return() *MockSession_Revoke_Call {
_c.Call.Return()
return _c
}
func (_c *MockSession_Revoke_Call) RunAndReturn(run func(time.Duration)) *MockSession_Revoke_Call {
_c.Run(run)
return _c
}
// SetDisconnected provides a mock function with given fields: b
func (_m *MockSession) SetDisconnected(b bool) {
_m.Called(b)

View File

@ -16,6 +16,8 @@
package sessionutil
import (
"context"
"github.com/blang/semver/v4"
)
@ -27,7 +29,7 @@ type SessionInterface interface {
String() string
Register()
GetSessions(prefix string) (map[string]*Session, int64, error)
GetSessions(ctx context.Context, prefix string) (map[string]*Session, int64, error)
GetSessionsWithVersionRange(prefix string, r semver.Range) (map[string]*Session, int64, error)
GoingStop() error

View File

@ -612,10 +612,10 @@ func (s *Session) startKeepAliveLoop() {
// GetSessions will get all sessions registered in etcd.
// Revision is returned for WatchServices to prevent key events from being missed.
func (s *Session) GetSessions(prefix string) (map[string]*Session, int64, error) {
func (s *Session) GetSessions(ctx context.Context, prefix string) (map[string]*Session, int64, error) {
res := make(map[string]*Session)
key := path.Join(s.metaRoot, DefaultServiceRoot, prefix)
resp, err := s.etcdCli.Get(s.ctx, key, clientv3.WithPrefix(),
resp, err := s.etcdCli.Get(ctx, key, clientv3.WithPrefix(),
clientv3.WithSort(clientv3.SortByKey, clientv3.SortAscend))
if err != nil {
return nil, 0, err
@ -868,7 +868,7 @@ func (w *sessionWatcher) handleWatchErr(err error) error {
return err
}
sessions, revision, err := w.s.GetSessions(w.prefix)
sessions, revision, err := w.s.GetSessions(w.s.ctx, w.prefix)
if err != nil {
log.Warn("GetSession before rewatch failed", zap.String("prefix", w.prefix), zap.Error(err))
w.closeEventCh()
@ -963,7 +963,7 @@ func (s *Session) ProcessActiveStandBy(activateFunc func() error) error {
registerActiveFn := func() (bool, int64, error) {
for _, role := range oldRoles {
sessions, _, err := s.GetSessions(role)
sessions, _, err := s.GetSessions(s.ctx, role)
if err != nil {
log.Debug("failed to get old sessions", zap.String("role", role), zap.Error(err))
continue

View File

@ -88,7 +88,7 @@ func TestInit(t *testing.T) {
assert.NotEqual(t, int64(0), s.LeaseID)
assert.NotEqual(t, int64(0), s.ServerID)
s.Register()
sessions, _, err := s.GetSessions("inittest")
sessions, _, err := s.GetSessions(ctx, "inittest")
assert.NoError(t, err)
assert.Contains(t, sessions, "inittest-"+strconv.FormatInt(s.ServerID, 10))
}
@ -111,7 +111,7 @@ func TestInitNoArgs(t *testing.T) {
assert.NotEqual(t, int64(0), s.LeaseID)
assert.NotEqual(t, int64(0), s.ServerID)
s.Register()
sessions, _, err := s.GetSessions("inittest")
sessions, _, err := s.GetSessions(ctx, "inittest")
assert.NoError(t, err)
assert.Contains(t, sessions, "inittest-"+strconv.FormatInt(s.ServerID, 10))
}
@ -131,7 +131,7 @@ func TestUpdateSessions(t *testing.T) {
s := NewSessionWithEtcd(ctx, metaRoot, etcdCli, WithResueNodeID(false))
sessions, rev, err := s.GetSessions("test")
sessions, rev, err := s.GetSessions(ctx, "test")
assert.NoError(t, err)
assert.Equal(t, len(sessions), 0)
watcher := s.WatchServices("test", rev, nil)
@ -155,15 +155,15 @@ func TestUpdateSessions(t *testing.T) {
wg.Wait()
assert.Eventually(t, func() bool {
sessions, _, _ := s.GetSessions("test")
sessions, _, _ := s.GetSessions(ctx, "test")
return len(sessions) == 10
}, 10*time.Second, 100*time.Millisecond)
notExistSessions, _, _ := s.GetSessions("testt")
notExistSessions, _, _ := s.GetSessions(ctx, "testt")
assert.Equal(t, len(notExistSessions), 0)
etcdKV.RemoveWithPrefix(ctx, metaRoot)
assert.Eventually(t, func() bool {
sessions, _, _ := s.GetSessions("test")
sessions, _, _ := s.GetSessions(ctx, "test")
return len(sessions) == 0
}, 10*time.Second, 100*time.Millisecond)