Add server id validation interceptor (#26395)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
This commit is contained in:
yihao.dai 2023-08-17 20:20:20 +08:00 committed by GitHub
parent 2539a19885
commit 63b86b32a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
49 changed files with 1397 additions and 355 deletions

View File

@ -367,7 +367,7 @@ func (suite *ClusterSuite) TestUnregister() {
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
var mockSessionCreator = func(ctx context.Context, addr string) (types.DataNode, error) {
var mockSessionCreator = func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) {
return newMockDataNodeClient(1, nil)
}
sessionManager := NewSessionManager(withSessionCreator(mockSessionCreator))
@ -414,7 +414,7 @@ func TestWatchIfNeeded(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
var mockSessionCreator = func(ctx context.Context, addr string) (types.DataNode, error) {
var mockSessionCreator = func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) {
return newMockDataNodeClient(1, nil)
}
sessionManager := NewSessionManager(withSessionCreator(mockSessionCreator))
@ -629,7 +629,7 @@ func TestCluster_ReCollectSegmentStats(t *testing.T) {
t.Run("recollect succeed", func(t *testing.T) {
ctx, cancel := context.WithCancel(context.TODO())
defer cancel()
var mockSessionCreator = func(ctx context.Context, addr string) (types.DataNode, error) {
var mockSessionCreator = func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) {
return newMockDataNodeClient(1, nil)
}
sessionManager := NewSessionManager(withSessionCreator(mockSessionCreator))

View File

@ -86,7 +86,7 @@ func (nm *IndexNodeManager) AddNode(nodeID UniqueID, address string) error {
err error
)
nodeClient, err = nm.indexNodeCreator(context.TODO(), address)
nodeClient, err = nm.indexNodeCreator(context.TODO(), address, nodeID)
if err != nil {
log.Error("create IndexNode client fail", zap.Error(err))
return err

View File

@ -68,7 +68,7 @@ func TestGetDataNodeMetrics(t *testing.T) {
_, err = svr.getDataNodeMetrics(ctx, req, NewSession(&NodeInfo{}, nil))
assert.Error(t, err)
creator := func(ctx context.Context, addr string) (types.DataNode, error) {
creator := func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) {
return newMockDataNodeClient(100, nil)
}
@ -80,8 +80,8 @@ func TestGetDataNodeMetrics(t *testing.T) {
assert.Equal(t, metricsinfo.ConstructComponentName(typeutil.DataNodeRole, 100), info.BaseComponentInfos.Name)
getMockFailedClientCreator := func(mockFunc func() (*milvuspb.GetMetricsResponse, error)) dataNodeCreatorFunc {
return func(ctx context.Context, addr string) (types.DataNode, error) {
cli, err := creator(ctx, addr)
return func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) {
cli, err := creator(ctx, addr, nodeID)
assert.NoError(t, err)
return &mockMetricDataNodeClient{DataNode: cli, mock: mockFunc}, nil
}

View File

@ -79,9 +79,9 @@ type (
Timestamp = typeutil.Timestamp
)
type dataNodeCreatorFunc func(ctx context.Context, addr string) (types.DataNode, error)
type dataNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error)
type indexNodeCreatorFunc func(ctx context.Context, addr string) (types.IndexNode, error)
type indexNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.IndexNode, error)
type rootCoordCreatorFunc func(ctx context.Context, metaRootPath string, etcdClient *clientv3.Client) (types.RootCoord, error)
@ -220,12 +220,12 @@ func CreateServer(ctx context.Context, factory dependency.Factory, opts ...Optio
return s
}
func defaultDataNodeCreatorFunc(ctx context.Context, addr string) (types.DataNode, error) {
return datanodeclient.NewClient(ctx, addr)
func defaultDataNodeCreatorFunc(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) {
return datanodeclient.NewClient(ctx, addr, nodeID)
}
func defaultIndexNodeCreatorFunc(ctx context.Context, addr string) (types.IndexNode, error) {
return indexnodeclient.NewClient(context.TODO(), addr, Params.DataCoordCfg.WithCredential.GetAsBool())
func defaultIndexNodeCreatorFunc(ctx context.Context, addr string, nodeID int64) (types.IndexNode, error) {
return indexnodeclient.NewClient(ctx, addr, nodeID, Params.DataCoordCfg.WithCredential.GetAsBool())
}
func defaultRootCoordCreatorFunc(ctx context.Context, metaRootPath string, client *clientv3.Client) (types.RootCoord, error) {
@ -427,11 +427,11 @@ func (s *Server) SetRootCoord(rootCoord types.RootCoord) {
s.rootCoordClient = rootCoord
}
func (s *Server) SetDataNodeCreator(f func(context.Context, string) (types.DataNode, error)) {
func (s *Server) SetDataNodeCreator(f func(context.Context, string, int64) (types.DataNode, error)) {
s.dataNodeCreator = f
}
func (s *Server) SetIndexNodeCreator(f func(context.Context, string) (types.IndexNode, error)) {
func (s *Server) SetIndexNodeCreator(f func(context.Context, string, int64) (types.IndexNode, error)) {
s.indexNodeCreator = f
}

View File

@ -3221,7 +3221,7 @@ func TestOptions(t *testing.T) {
t.Run("WithDataNodeCreator", func(t *testing.T) {
var target int64
var val = rand.Int63()
opt := WithDataNodeCreator(func(context.Context, string) (types.DataNode, error) {
opt := WithDataNodeCreator(func(context.Context, string, int64) (types.DataNode, error) {
target = val
return nil, nil
})
@ -3230,7 +3230,7 @@ func TestOptions(t *testing.T) {
factory := dependency.NewDefaultFactory(true)
svr := CreateServer(context.TODO(), factory, opt)
dn, err := svr.dataNodeCreator(context.Background(), "")
dn, err := svr.dataNodeCreator(context.Background(), "", 1)
assert.Nil(t, dn)
assert.NoError(t, err)
assert.Equal(t, target, val)
@ -3945,7 +3945,7 @@ func newTestServer(t *testing.T, receiveCh chan any, opts ...Option) *Server {
svr := CreateServer(context.TODO(), factory)
svr.SetEtcdClient(etcdCli)
svr.dataNodeCreator = func(ctx context.Context, addr string) (types.DataNode, error) {
svr.dataNodeCreator = func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) {
return newMockDataNodeClient(0, receiveCh)
}
svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) {
@ -3997,7 +3997,7 @@ func newTestServerWithMeta(t *testing.T, receiveCh chan any, meta *meta, opts ..
svr := CreateServer(context.TODO(), factory, opts...)
svr.SetEtcdClient(etcdCli)
svr.dataNodeCreator = func(ctx context.Context, addr string) (types.DataNode, error) {
svr.dataNodeCreator = func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) {
return newMockDataNodeClient(0, receiveCh)
}
svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) {
@ -4052,7 +4052,7 @@ func newTestServer2(t *testing.T, receiveCh chan any, opts ...Option) *Server {
svr := CreateServer(context.TODO(), factory, opts...)
svr.SetEtcdClient(etcdCli)
svr.dataNodeCreator = func(ctx context.Context, addr string) (types.DataNode, error) {
svr.dataNodeCreator = func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) {
return newMockDataNodeClient(0, receiveCh)
}
svr.rootCoordClientCreator = func(ctx context.Context, metaRootPath string, etcdCli *clientv3.Client) (types.RootCoord, error) {
@ -4097,7 +4097,7 @@ func Test_CheckHealth(t *testing.T) {
data map[int64]*Session
}{data: map[int64]*Session{1: {
client: healthClient,
clientCreator: func(ctx context.Context, addr string) (types.DataNode, error) {
clientCreator: func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) {
return healthClient, nil
},
}}}
@ -4122,7 +4122,7 @@ func Test_CheckHealth(t *testing.T) {
data map[int64]*Session
}{data: map[int64]*Session{1: {
client: unhealthClient,
clientCreator: func(ctx context.Context, addr string) (types.DataNode, error) {
clientCreator: func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) {
return unhealthClient, nil
},
}}}
@ -4244,10 +4244,10 @@ func testDataCoordBase(t *testing.T, opts ...Option) *Server {
svr := CreateServer(ctx, factory, opts...)
svr.SetEtcdClient(etcdCli)
svr.SetDataNodeCreator(func(ctx context.Context, addr string) (types.DataNode, error) {
svr.SetDataNodeCreator(func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) {
return newMockDataNodeClient(0, nil)
})
svr.SetIndexNodeCreator(func(ctx context.Context, addr string) (types.IndexNode, error) {
svr.SetIndexNodeCreator(func(ctx context.Context, addr string, nodeID int64) (types.IndexNode, error) {
return indexnode.NewMockIndexNodeComponent(ctx)
})
svr.SetRootCoord(newMockRootCoordService())

View File

@ -73,7 +73,7 @@ func (n *Session) GetOrCreateClient(ctx context.Context) (types.DataNode, error)
}
func (n *Session) initClient(ctx context.Context) (err error) {
if n.client, err = n.clientCreator(ctx, n.info.Address); err != nil {
if n.client, err = n.clientCreator(ctx, n.info.Address, n.info.NodeID); err != nil {
return
}
if err = n.client.Init(); err != nil {

View File

@ -59,8 +59,8 @@ func withSessionCreator(creator dataNodeCreatorFunc) SessionOpt {
}
func defaultSessionCreator() dataNodeCreatorFunc {
return func(ctx context.Context, addr string) (types.DataNode, error) {
return grpcdatanodeclient.NewClient(ctx, addr)
return func(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) {
return grpcdatanodeclient.NewClient(ctx, addr, nodeID)
}
}

View File

@ -154,11 +154,13 @@ func (s *Server) startGrpcLoop(grpcPort int) {
otelgrpc.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
otelgrpc.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(),
)))
indexpb.RegisterIndexCoordServer(s.grpcServer, s)
datapb.RegisterDataCoordServer(s.grpcServer, s)

View File

@ -107,10 +107,10 @@ func (m *MockDataCoord) SetEtcdClient(etcdClient *clientv3.Client) {
func (m *MockDataCoord) SetRootCoord(rootCoord types.RootCoord) {
}
func (m *MockDataCoord) SetDataNodeCreator(func(context.Context, string) (types.DataNode, error)) {
func (m *MockDataCoord) SetDataNodeCreator(func(context.Context, string, int64) (types.DataNode, error)) {
}
func (m *MockDataCoord) SetIndexNodeCreator(func(context.Context, string) (types.IndexNode, error)) {
func (m *MockDataCoord) SetIndexNodeCreator(func(context.Context, string, int64) (types.IndexNode, error)) {
}
func (m *MockDataCoord) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) {

View File

@ -42,7 +42,7 @@ type Client struct {
}
// NewClient creates a client for DataNode.
func NewClient(ctx context.Context, addr string) (*Client, error) {
func NewClient(ctx context.Context, addr string, nodeID int64) (*Client, error) {
if addr == "" {
return nil, fmt.Errorf("address is empty")
}
@ -66,6 +66,7 @@ func NewClient(ctx context.Context, addr string) (*Client, error) {
client.grpcClient.SetRole(typeutil.DataNodeRole)
client.grpcClient.SetGetAddrFunc(client.getAddr)
client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient)
client.grpcClient.SetNodeID(nodeID)
return client, nil
}

View File

@ -33,11 +33,11 @@ import (
func Test_NewClient(t *testing.T) {
proxy.Params.Init()
ctx := context.Background()
client, err := NewClient(ctx, "")
client, err := NewClient(ctx, "", 1)
assert.Nil(t, client)
assert.Error(t, err)
client, err = NewClient(ctx, "test")
client, err = NewClient(ctx, "test", 2)
assert.NoError(t, err)
assert.NotNil(t, client)

View File

@ -138,11 +138,13 @@ func (s *Server) startGrpcLoop(grpcPort int) {
otelgrpc.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
otelgrpc.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(),
)))
datapb.RegisterDataNodeServer(s.grpcServer, s)

View File

@ -43,7 +43,7 @@ type Client struct {
}
// NewClient creates a new IndexNode client.
func NewClient(ctx context.Context, addr string, encryption bool) (*Client, error) {
func NewClient(ctx context.Context, addr string, nodeID int64, encryption bool) (*Client, error) {
if addr == "" {
return nil, fmt.Errorf("address is empty")
}
@ -67,6 +67,7 @@ func NewClient(ctx context.Context, addr string, encryption bool) (*Client, erro
client.grpcClient.SetRole(typeutil.IndexNodeRole)
client.grpcClient.SetGetAddrFunc(client.getAddr)
client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient)
client.grpcClient.SetNodeID(nodeID)
if encryption {
client.grpcClient.EnableEncryption()
}

View File

@ -40,11 +40,11 @@ import (
func Test_NewClient(t *testing.T) {
paramtable.Init()
ctx := context.Background()
client, err := NewClient(ctx, "", false)
client, err := NewClient(ctx, "", 1, false)
assert.Nil(t, client)
assert.Error(t, err)
client, err = NewClient(ctx, "test", false)
client, err = NewClient(ctx, "test", 2, false)
assert.NoError(t, err)
assert.NotNil(t, client)
@ -148,7 +148,7 @@ func TestIndexNodeClient(t *testing.T) {
err = ins.Run()
assert.NoError(t, err)
inc, err := NewClient(ctx, "localhost:21121", false)
inc, err := NewClient(ctx, "localhost:21121", paramtable.GetNodeID(), false)
assert.NoError(t, err)
assert.NotNil(t, inc)

View File

@ -110,11 +110,13 @@ func (s *Server) startGrpcLoop(grpcPort int) {
otelgrpc.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
otelgrpc.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(),
)))
indexpb.RegisterIndexNodeServer(s.grpcServer, s)
go funcutil.CheckGrpcReady(ctx, s.grpcErrChan)

View File

@ -42,7 +42,7 @@ type Client struct {
}
// NewClient creates a new client instance
func NewClient(ctx context.Context, addr string) (*Client, error) {
func NewClient(ctx context.Context, addr string, nodeID int64) (*Client, error) {
if addr == "" {
return nil, fmt.Errorf("address is empty")
}
@ -66,6 +66,7 @@ func NewClient(ctx context.Context, addr string) (*Client, error) {
client.grpcClient.SetRole(typeutil.ProxyRole)
client.grpcClient.SetGetAddrFunc(client.getAddr)
client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient)
client.grpcClient.SetNodeID(nodeID)
return client, nil
}

View File

@ -34,11 +34,11 @@ func Test_NewClient(t *testing.T) {
proxy.Params.Init()
ctx := context.Background()
client, err := NewClient(ctx, "")
client, err := NewClient(ctx, "", 1)
assert.Nil(t, client)
assert.Error(t, err)
client, err = NewClient(ctx, "test")
client, err = NewClient(ctx, "test", 2)
assert.NoError(t, err)
assert.NotNil(t, client)

View File

@ -314,9 +314,12 @@ func (s *Server) startInternalGrpc(grpcPort int, errChan chan error) {
otelgrpc.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(),
)),
grpc.StreamInterceptor(interceptor.ClusterValidationStreamServerInterceptor()),
)
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
interceptor.ClusterValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(),
)))
proxypb.RegisterProxyServer(s.grpcInternalServer, s)
grpc_health_v1.RegisterHealthServer(s.grpcInternalServer, s)
errChan <- nil

View File

@ -740,7 +740,7 @@ func (m *MockProxy) SetQueryCoordClient(queryCoord types.QueryCoord) {
}
func (m *MockProxy) SetQueryNodeCreator(func(ctx context.Context, addr string) (types.QueryNode, error)) {
func (m *MockProxy) SetQueryNodeCreator(func(ctx context.Context, addr string, nodeID int64) (types.QueryNode, error)) {
}

View File

@ -228,11 +228,13 @@ func (s *Server) startGrpcLoop(grpcPort int) {
otelgrpc.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
otelgrpc.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(),
)))
querypb.RegisterQueryCoordServer(s.grpcServer, s)

View File

@ -41,7 +41,7 @@ type Client struct {
}
// NewClient creates a new QueryNode client.
func NewClient(ctx context.Context, addr string) (*Client, error) {
func NewClient(ctx context.Context, addr string, nodeID int64) (*Client, error) {
if addr == "" {
return nil, fmt.Errorf("addr is empty")
}
@ -65,6 +65,7 @@ func NewClient(ctx context.Context, addr string) (*Client, error) {
client.grpcClient.SetRole(typeutil.QueryNodeRole)
client.grpcClient.SetGetAddrFunc(client.getAddr)
client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient)
client.grpcClient.SetNodeID(nodeID)
return client, nil
}

View File

@ -33,11 +33,11 @@ func Test_NewClient(t *testing.T) {
paramtable.Init()
ctx := context.Background()
client, err := NewClient(ctx, "")
client, err := NewClient(ctx, "", 1)
assert.Nil(t, client)
assert.Error(t, err)
client, err = NewClient(ctx, "test")
client, err = NewClient(ctx, "test", 2)
assert.NoError(t, err)
assert.NotNil(t, client)

View File

@ -183,11 +183,13 @@ func (s *Server) startGrpcLoop(grpcPort int) {
otelgrpc.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
otelgrpc.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(),
)))
querypb.RegisterQueryNodeServer(s.grpcServer, s)

View File

@ -255,11 +255,13 @@ func (s *Server) startGrpcLoop(port int) {
otelgrpc.UnaryServerInterceptor(opts...),
logutil.UnaryTraceLoggerInterceptor,
interceptor.ClusterValidationUnaryServerInterceptor(),
interceptor.ServerIDValidationUnaryServerInterceptor(),
)),
grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(
otelgrpc.StreamServerInterceptor(opts...),
logutil.StreamTraceLoggerInterceptor,
interceptor.ClusterValidationStreamServerInterceptor(),
interceptor.ServerIDValidationStreamServerInterceptor(),
)))
rootcoordpb.RegisterRootCoordServer(s.grpcServer, s)

View File

@ -84,7 +84,7 @@ func (m *mockCore) SetQueryCoord(types.QueryCoord) error {
return nil
}
func (m *mockCore) SetProxyCreator(func(ctx context.Context, addr string) (types.Proxy, error)) {
func (m *mockCore) SetProxyCreator(func(ctx context.Context, addr string, nodeID int64) (types.Proxy, error)) {
}
func (m *mockCore) Register() error {

View File

@ -2236,7 +2236,7 @@ func (_c *MockDataCoord_SetAddress_Call) RunAndReturn(run func(string)) *MockDat
}
// SetDataNodeCreator provides a mock function with given fields: _a0
func (_m *MockDataCoord) SetDataNodeCreator(_a0 func(context.Context, string) (types.DataNode, error)) {
func (_m *MockDataCoord) SetDataNodeCreator(_a0 func(context.Context, string, int64) (types.DataNode, error)) {
_m.Called(_a0)
}
@ -2246,14 +2246,14 @@ type MockDataCoord_SetDataNodeCreator_Call struct {
}
// SetDataNodeCreator is a helper method to define mock.On call
// - _a0 func(context.Context , string)(types.DataNode , error)
// - _a0 func(context.Context , string , int64)(types.DataNode , error)
func (_e *MockDataCoord_Expecter) SetDataNodeCreator(_a0 interface{}) *MockDataCoord_SetDataNodeCreator_Call {
return &MockDataCoord_SetDataNodeCreator_Call{Call: _e.mock.On("SetDataNodeCreator", _a0)}
}
func (_c *MockDataCoord_SetDataNodeCreator_Call) Run(run func(_a0 func(context.Context, string) (types.DataNode, error))) *MockDataCoord_SetDataNodeCreator_Call {
func (_c *MockDataCoord_SetDataNodeCreator_Call) Run(run func(_a0 func(context.Context, string, int64) (types.DataNode, error))) *MockDataCoord_SetDataNodeCreator_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(func(context.Context, string) (types.DataNode, error)))
run(args[0].(func(context.Context, string, int64) (types.DataNode, error)))
})
return _c
}
@ -2263,7 +2263,7 @@ func (_c *MockDataCoord_SetDataNodeCreator_Call) Return() *MockDataCoord_SetData
return _c
}
func (_c *MockDataCoord_SetDataNodeCreator_Call) RunAndReturn(run func(func(context.Context, string) (types.DataNode, error))) *MockDataCoord_SetDataNodeCreator_Call {
func (_c *MockDataCoord_SetDataNodeCreator_Call) RunAndReturn(run func(func(context.Context, string, int64) (types.DataNode, error))) *MockDataCoord_SetDataNodeCreator_Call {
_c.Call.Return(run)
return _c
}
@ -2302,7 +2302,7 @@ func (_c *MockDataCoord_SetEtcdClient_Call) RunAndReturn(run func(*clientv3.Clie
}
// SetIndexNodeCreator provides a mock function with given fields: _a0
func (_m *MockDataCoord) SetIndexNodeCreator(_a0 func(context.Context, string) (types.IndexNode, error)) {
func (_m *MockDataCoord) SetIndexNodeCreator(_a0 func(context.Context, string, int64) (types.IndexNode, error)) {
_m.Called(_a0)
}
@ -2312,14 +2312,14 @@ type MockDataCoord_SetIndexNodeCreator_Call struct {
}
// SetIndexNodeCreator is a helper method to define mock.On call
// - _a0 func(context.Context , string)(types.IndexNode , error)
// - _a0 func(context.Context , string , int64)(types.IndexNode , error)
func (_e *MockDataCoord_Expecter) SetIndexNodeCreator(_a0 interface{}) *MockDataCoord_SetIndexNodeCreator_Call {
return &MockDataCoord_SetIndexNodeCreator_Call{Call: _e.mock.On("SetIndexNodeCreator", _a0)}
}
func (_c *MockDataCoord_SetIndexNodeCreator_Call) Run(run func(_a0 func(context.Context, string) (types.IndexNode, error))) *MockDataCoord_SetIndexNodeCreator_Call {
func (_c *MockDataCoord_SetIndexNodeCreator_Call) Run(run func(_a0 func(context.Context, string, int64) (types.IndexNode, error))) *MockDataCoord_SetIndexNodeCreator_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(func(context.Context, string) (types.IndexNode, error)))
run(args[0].(func(context.Context, string, int64) (types.IndexNode, error)))
})
return _c
}
@ -2329,7 +2329,7 @@ func (_c *MockDataCoord_SetIndexNodeCreator_Call) Return() *MockDataCoord_SetInd
return _c
}
func (_c *MockDataCoord_SetIndexNodeCreator_Call) RunAndReturn(run func(func(context.Context, string) (types.IndexNode, error))) *MockDataCoord_SetIndexNodeCreator_Call {
func (_c *MockDataCoord_SetIndexNodeCreator_Call) RunAndReturn(run func(func(context.Context, string, int64) (types.IndexNode, error))) *MockDataCoord_SetIndexNodeCreator_Call {
_c.Call.Return(run)
return _c
}

File diff suppressed because it is too large Load Diff

View File

@ -1210,7 +1210,7 @@ func (_c *MockQueryCoord_SetEtcdClient_Call) RunAndReturn(run func(*clientv3.Cli
}
// SetQueryNodeCreator provides a mock function with given fields: _a0
func (_m *MockQueryCoord) SetQueryNodeCreator(_a0 func(context.Context, string) (types.QueryNode, error)) {
func (_m *MockQueryCoord) SetQueryNodeCreator(_a0 func(context.Context, string, int64) (types.QueryNode, error)) {
_m.Called(_a0)
}
@ -1220,14 +1220,14 @@ type MockQueryCoord_SetQueryNodeCreator_Call struct {
}
// SetQueryNodeCreator is a helper method to define mock.On call
// - _a0 func(context.Context , string)(types.QueryNode , error)
// - _a0 func(context.Context , string , int64)(types.QueryNode , error)
func (_e *MockQueryCoord_Expecter) SetQueryNodeCreator(_a0 interface{}) *MockQueryCoord_SetQueryNodeCreator_Call {
return &MockQueryCoord_SetQueryNodeCreator_Call{Call: _e.mock.On("SetQueryNodeCreator", _a0)}
}
func (_c *MockQueryCoord_SetQueryNodeCreator_Call) Run(run func(_a0 func(context.Context, string) (types.QueryNode, error))) *MockQueryCoord_SetQueryNodeCreator_Call {
func (_c *MockQueryCoord_SetQueryNodeCreator_Call) Run(run func(_a0 func(context.Context, string, int64) (types.QueryNode, error))) *MockQueryCoord_SetQueryNodeCreator_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(func(context.Context, string) (types.QueryNode, error)))
run(args[0].(func(context.Context, string, int64) (types.QueryNode, error)))
})
return _c
}
@ -1237,7 +1237,7 @@ func (_c *MockQueryCoord_SetQueryNodeCreator_Call) Return() *MockQueryCoord_SetQ
return _c
}
func (_c *MockQueryCoord_SetQueryNodeCreator_Call) RunAndReturn(run func(func(context.Context, string) (types.QueryNode, error))) *MockQueryCoord_SetQueryNodeCreator_Call {
func (_c *MockQueryCoord_SetQueryNodeCreator_Call) RunAndReturn(run func(func(context.Context, string, int64) (types.QueryNode, error))) *MockQueryCoord_SetQueryNodeCreator_Call {
_c.Call.Return(run)
return _c
}

View File

@ -484,7 +484,7 @@ func (node *Proxy) SetQueryCoordClient(cli types.QueryCoord) {
node.queryCoord = cli
}
func (node *Proxy) SetQueryNodeCreator(f func(ctx context.Context, addr string) (types.QueryNode, error)) {
func (node *Proxy) SetQueryNodeCreator(f func(ctx context.Context, addr string, nodeID int64) (types.QueryNode, error)) {
node.shardMgr.SetClientCreatorFunc(f)
}

View File

@ -52,7 +52,7 @@ func TestProxyRpcLimit(t *testing.T) {
go testServer.startGrpc(ctx, &wg, &p)
assert.NoError(t, testServer.waitForGrpcReady())
defer testServer.grpcServer.Stop()
client, err := grpcproxyclient.NewClient(ctx, "localhost:"+p.Port.GetValue())
client, err := grpcproxyclient.NewClient(ctx, "localhost:"+p.Port.GetValue(), 1)
assert.NoError(t, err)
proxy.stateCode.Store(commonpb.StateCode_Healthy)

View File

@ -11,7 +11,7 @@ import (
"github.com/milvus-io/milvus/internal/types"
)
type queryNodeCreatorFunc func(ctx context.Context, addr string) (types.QueryNode, error)
type queryNodeCreatorFunc func(ctx context.Context, addr string, nodeID int64) (types.QueryNode, error)
type nodeInfo struct {
nodeID UniqueID
@ -114,8 +114,8 @@ func withShardClientCreator(creator queryNodeCreatorFunc) shardClientMgrOpt {
return func(s shardClientMgr) { s.SetClientCreatorFunc(creator) }
}
func defaultQueryNodeClientCreator(ctx context.Context, addr string) (types.QueryNode, error) {
return qnClient.NewClient(ctx, addr)
func defaultQueryNodeClientCreator(ctx context.Context, addr string, nodeID int64) (types.QueryNode, error) {
return qnClient.NewClient(ctx, addr, nodeID)
}
// NewShardClientMgr creates a new shardClientMgr
@ -178,7 +178,7 @@ func (c *shardClientMgrImpl) UpdateShardLeaders(oldLeaders map[string][]nodeInfo
if c.clientCreator == nil {
return fmt.Errorf("clientCreator function is nil")
}
shardClient, err := c.clientCreator(context.Background(), node.address)
shardClient, err := c.clientCreator(context.Background(), node.address, node.nodeID)
if err != nil {
return err
}

View File

@ -31,7 +31,7 @@ func TestShardClientMgr_UpdateShardLeaders_CreatorNil(t *testing.T) {
}
func TestShardClientMgr_UpdateShardLeaders_Empty(t *testing.T) {
mockCreator := func(ctx context.Context, addr string) (types.QueryNode, error) {
mockCreator := func(ctx context.Context, addr string, nodeID int64) (types.QueryNode, error) {
return &mock.QueryNodeClient{}, nil
}
mgr := newShardClientMgr(withShardClientCreator(mockCreator))

View File

@ -569,7 +569,7 @@ func (s *Server) SetDataCoord(dataCoord types.DataCoord) error {
return nil
}
func (s *Server) SetQueryNodeCreator(f func(ctx context.Context, addr string) (types.QueryNode, error)) {
func (s *Server) SetQueryNodeCreator(f func(ctx context.Context, addr string, nodeID int64) (types.QueryNode, error)) {
s.queryNodeCreator = f
}

View File

@ -72,10 +72,10 @@ type QueryCluster struct {
stopOnce sync.Once
}
type QueryNodeCreator func(ctx context.Context, addr string) (types.QueryNode, error)
type QueryNodeCreator func(ctx context.Context, addr string, nodeID int64) (types.QueryNode, error)
func DefaultQueryNodeCreator(ctx context.Context, addr string) (types.QueryNode, error) {
return grpcquerynodeclient.NewClient(ctx, addr)
func DefaultQueryNodeCreator(ctx context.Context, addr string, nodeID int64) (types.QueryNode, error) {
return grpcquerynodeclient.NewClient(ctx, addr, nodeID)
}
func NewCluster(nodeManager *NodeManager, queryNodeCreator QueryNodeCreator) *QueryCluster {
@ -303,8 +303,8 @@ func (c *clients) getOrCreate(ctx context.Context, node *NodeInfo) (types.QueryN
return c.create(node)
}
func createNewClient(ctx context.Context, addr string, queryNodeCreator QueryNodeCreator) (types.QueryNode, error) {
newCli, err := queryNodeCreator(ctx, addr)
func createNewClient(ctx context.Context, addr string, nodeID int64, queryNodeCreator QueryNodeCreator) (types.QueryNode, error) {
newCli, err := queryNodeCreator(ctx, addr, nodeID)
if err != nil {
return nil, err
}
@ -323,7 +323,7 @@ func (c *clients) create(node *NodeInfo) (types.QueryNode, error) {
if cli, ok := c.clients[node.ID()]; ok {
return cli, nil
}
cli, err := createNewClient(context.Background(), node.Addr(), c.queryNodeCreator)
cli, err := createNewClient(context.Background(), node.Addr(), node.ID(), c.queryNodeCreator)
if err != nil {
return nil, err
}

View File

@ -315,7 +315,7 @@ func (node *QueryNode) Init() error {
}
}
client, err := grpcquerynodeclient.NewClient(node.ctx, addr)
client, err := grpcquerynodeclient.NewClient(node.ctx, addr, nodeID)
if err != nil {
return nil, err
}

View File

@ -36,10 +36,10 @@ import (
"github.com/milvus-io/milvus/pkg/util/metricsinfo"
)
type proxyCreator func(ctx context.Context, addr string) (types.Proxy, error)
type proxyCreator func(ctx context.Context, addr string, nodeID int64) (types.Proxy, error)
func DefaultProxyCreator(ctx context.Context, addr string) (types.Proxy, error) {
cli, err := grpcproxyclient.NewClient(ctx, addr)
func DefaultProxyCreator(ctx context.Context, addr string, nodeID int64) (types.Proxy, error) {
cli, err := grpcproxyclient.NewClient(ctx, addr, nodeID)
if err != nil {
return nil, err
}
@ -107,7 +107,7 @@ func (p *proxyClientManager) updateProxyNumMetric() {
}
func (p *proxyClientManager) connect(session *sessionutil.Session) {
pc, err := p.creator(context.Background(), session.Address)
pc, err := p.creator(context.Background(), session.Address, session.ServerID)
if err != nil {
log.Warn("failed to create proxy client", zap.String("address", session.Address), zap.Int64("serverID", session.ServerID), zap.Error(err))
return

View File

@ -116,7 +116,7 @@ func TestProxyClientManager_GetProxyClients(t *testing.T) {
defer cli.Close()
assert.NoError(t, err)
core.etcdCli = cli
core.proxyCreator = func(ctx context.Context, addr string) (types.Proxy, error) {
core.proxyCreator = func(ctx context.Context, addr string, nodeID int64) (types.Proxy, error) {
return nil, errors.New("failed")
}
@ -148,7 +148,7 @@ func TestProxyClientManager_AddProxyClient(t *testing.T) {
defer cli.Close()
core.etcdCli = cli
core.proxyCreator = func(ctx context.Context, addr string) (types.Proxy, error) {
core.proxyCreator = func(ctx context.Context, addr string, nodeID int64) (types.Proxy, error) {
return nil, errors.New("failed")
}

View File

@ -252,7 +252,7 @@ func (c *Core) tsLoop() {
}
}
func (c *Core) SetProxyCreator(f func(ctx context.Context, addr string) (types.Proxy, error)) {
func (c *Core) SetProxyCreator(f func(ctx context.Context, addr string, nodeID int64) (types.Proxy, error)) {
c.proxyCreator = f
}

View File

@ -396,10 +396,10 @@ type DataCoordComponent interface {
SetRootCoord(rootCoord RootCoord)
// SetDataNodeCreator set DataNode client creator func for DataCoord
SetDataNodeCreator(func(context.Context, string) (DataNode, error))
SetDataNodeCreator(func(context.Context, string, int64) (DataNode, error))
//SetIndexNodeCreator set Index client creator func for DataCoord
SetIndexNodeCreator(func(context.Context, string) (IndexNode, error))
SetIndexNodeCreator(func(context.Context, string, int64) (IndexNode, error))
}
// IndexNode is the interface `indexnode` package implements
@ -837,7 +837,7 @@ type RootCoordComponent interface {
SetQueryCoord(queryCoord QueryCoord) error
// SetProxyCreator set Proxy client creator func for RootCoord
SetProxyCreator(func(ctx context.Context, addr string) (Proxy, error))
SetProxyCreator(func(ctx context.Context, addr string, nodeID int64) (Proxy, error))
// GetMetrics notifies RootCoordComponent to collect metrics for specified component
GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error)
@ -917,7 +917,7 @@ type ProxyComponent interface {
SetQueryCoordClient(queryCoord QueryCoord)
// SetQueryNodeCreator set QueryNode client creator func for Proxy
SetQueryNodeCreator(func(ctx context.Context, addr string) (QueryNode, error))
SetQueryNodeCreator(func(ctx context.Context, addr string, nodeID int64) (QueryNode, error))
// GetRateLimiter returns the rateLimiter in Proxy
GetRateLimiter() (Limiter, error)
@ -1542,5 +1542,5 @@ type QueryCoordComponent interface {
SetRootCoord(rootCoord RootCoord) error
// SetQueryNodeCreator set QueryNode client creator func for QueryCoord
SetQueryNodeCreator(func(ctx context.Context, addr string) (QueryNode, error))
SetQueryNodeCreator(func(ctx context.Context, addr string, nodeID int64) (QueryNode, error))
}

View File

@ -201,10 +201,12 @@ func (c *ClientBase[T]) connect(ctx context.Context) error {
grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient(
otelgrpc.UnaryClientInterceptor(opts...),
interceptor.ClusterInjectionUnaryClientInterceptor(),
interceptor.ServerIDInjectionUnaryClientInterceptor(c.GetNodeID()),
)),
grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient(
otelgrpc.StreamClientInterceptor(opts...),
interceptor.ClusterInjectionStreamClientInterceptor(),
interceptor.ServerIDInjectionStreamClientInterceptor(c.GetNodeID()),
)),
grpc.WithDefaultServiceConfig(retryPolicy),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
@ -239,10 +241,12 @@ func (c *ClientBase[T]) connect(ctx context.Context) error {
grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient(
otelgrpc.UnaryClientInterceptor(opts...),
interceptor.ClusterInjectionUnaryClientInterceptor(),
interceptor.ServerIDInjectionUnaryClientInterceptor(c.GetNodeID()),
)),
grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient(
otelgrpc.StreamClientInterceptor(opts...),
interceptor.ClusterInjectionStreamClientInterceptor(),
interceptor.ServerIDInjectionStreamClientInterceptor(c.GetNodeID()),
)),
grpc.WithDefaultServiceConfig(retryPolicy),
grpc.WithKeepaliveParams(keepalive.ClientParameters{
@ -279,6 +283,7 @@ func (c *ClientBase[T]) connect(ctx context.Context) error {
}
func (c *ClientBase[T]) callOnce(ctx context.Context, caller func(client T) (any, error)) (any, error) {
log := log.Ctx(ctx).With(zap.String("role", c.GetRole()))
client, err := c.GetGrpcClient(ctx)
if err != nil {
return generic.Zero[T](), err
@ -295,21 +300,20 @@ func (c *ClientBase[T]) callOnce(ctx context.Context, caller func(client T) (any
return generic.Zero[T](), err
}
if IsCrossClusterRoutingErr(err) {
log.Ctx(ctx).Warn("CrossClusterRoutingErr, start to reset connection",
zap.String("role", c.GetRole()),
zap.Error(err),
)
log.Warn("CrossClusterRoutingErr, start to reset connection", zap.Error(err))
c.resetConnection(client)
return ret, merr.ErrServiceUnavailable // For concealing ErrCrossClusterRouting from the client
}
if IsServerIDMismatchErr(err) {
log.Warn("Server ID mismatch, start to reset connection", zap.Error(err))
c.resetConnection(client)
return ret, err
}
if !funcutil.IsGrpcErr(err) {
log.Ctx(ctx).Warn("ClientBase:isNotGrpcErr", zap.Error(err))
log.Warn("ClientBase:isNotGrpcErr", zap.Error(err))
return generic.Zero[T](), err
}
log.Ctx(ctx).Info("ClientBase grpc error, start to reset connection",
zap.String("role", c.GetRole()),
zap.Error(err),
)
log.Info("ClientBase grpc error, start to reset connection", zap.Error(err))
c.resetConnection(client)
return ret, err
}
@ -398,3 +402,9 @@ func IsCrossClusterRoutingErr(err error) bool {
// hence it is not viable to employ the `errors.Is` for assessment.
return strings.Contains(err.Error(), merr.ErrCrossClusterRouting.Error())
}
func IsServerIDMismatchErr(err error) bool {
// GRPC utilizes `status.Status` to encapsulate errors,
// hence it is not viable to employ the `errors.Is` for assessment.
return strings.Contains(err.Error(), merr.ErrServerIDMismatch.Error())
}

View File

@ -29,25 +29,6 @@ import (
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
type mockSS struct {
grpc.ServerStream
ctx context.Context
}
func newMockSS(ctx context.Context) grpc.ServerStream {
return &mockSS{
ctx: ctx,
}
}
func (m *mockSS) Context() context.Context {
return m.ctx
}
func init() {
paramtable.Get().Init()
}
func TestClusterInterceptor(t *testing.T) {
t.Run("test ClusterInjectionUnaryClientInterceptor", func(t *testing.T) {
method := "MockMethod"

View File

@ -0,0 +1,44 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package interceptor
import (
"context"
"google.golang.org/grpc"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
type mockSS struct {
grpc.ServerStream
ctx context.Context
}
func newMockSS(ctx context.Context) grpc.ServerStream {
return &mockSS{
ctx: ctx,
}
}
func (m *mockSS) Context() context.Context {
return m.ctx
}
func init() {
paramtable.Get().Init()
}

View File

@ -0,0 +1,95 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package interceptor
import (
"context"
"fmt"
"strconv"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
const ServerIDKey = "ServerID"
// ServerIDValidationUnaryServerInterceptor returns a new unary server interceptor that
// verifies whether the target server ID of request matches with the server's ID and rejects it accordingly.
func ServerIDValidationUnaryServerInterceptor() grpc.UnaryServerInterceptor {
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return handler(ctx, req)
}
values := md.Get(ServerIDKey)
if len(values) == 0 {
return handler(ctx, req)
}
serverID, err := strconv.ParseInt(values[0], 10, 64)
if err != nil {
return handler(ctx, req)
}
if serverID != paramtable.GetNodeID() {
return nil, merr.WrapErrServerIDMismatch(serverID, paramtable.GetNodeID())
}
return handler(ctx, req)
}
}
// ServerIDValidationStreamServerInterceptor returns a new streaming server interceptor that
// verifies whether the target server ID of request matches with the server's ID and rejects it accordingly.
func ServerIDValidationStreamServerInterceptor() grpc.StreamServerInterceptor {
return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
md, ok := metadata.FromIncomingContext(ss.Context())
if !ok {
return handler(srv, ss)
}
values := md.Get(ServerIDKey)
if len(values) == 0 {
return handler(srv, ss)
}
serverID, err := strconv.ParseInt(values[0], 10, 64)
if err != nil {
return handler(srv, ss)
}
if serverID != paramtable.GetNodeID() {
return merr.WrapErrServerIDMismatch(serverID, paramtable.GetNodeID())
}
return handler(srv, ss)
}
}
// ServerIDInjectionUnaryClientInterceptor returns a new unary client interceptor that
// injects target server ID into the request.
func ServerIDInjectionUnaryClientInterceptor(targetServerID int64) grpc.UnaryClientInterceptor {
return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
ctx = metadata.AppendToOutgoingContext(ctx, ServerIDKey, fmt.Sprint(targetServerID))
return invoker(ctx, method, req, reply, cc, opts...)
}
}
// ServerIDInjectionStreamClientInterceptor returns a new streaming client interceptor that
// injects target server ID into the request.
func ServerIDInjectionStreamClientInterceptor(targetServerID int64) grpc.StreamClientInterceptor {
return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
ctx = metadata.AppendToOutgoingContext(ctx, ServerIDKey, fmt.Sprint(targetServerID))
return streamer(ctx, desc, cc, method, opts...)
}
}

View File

@ -0,0 +1,144 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package interceptor
import (
"context"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus/pkg/util/merr"
"github.com/milvus-io/milvus/pkg/util/paramtable"
)
func TestServerIDInterceptor(t *testing.T) {
t.Run("test ServerIDInjectionUnaryClientInterceptor", func(t *testing.T) {
method := "MockMethod"
req := &milvuspb.InsertRequest{}
serverID := int64(1)
var incomingContext context.Context
invoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
incomingContext = ctx
return nil
}
interceptor := ServerIDInjectionUnaryClientInterceptor(serverID)
ctx := metadata.NewOutgoingContext(context.Background(), metadata.New(make(map[string]string)))
err := interceptor(ctx, method, req, nil, nil, invoker)
assert.NoError(t, err)
md, ok := metadata.FromOutgoingContext(incomingContext)
assert.True(t, ok)
assert.Equal(t, fmt.Sprint(serverID), md.Get(ServerIDKey)[0])
})
t.Run("test ServerIDInjectionStreamClientInterceptor", func(t *testing.T) {
method := "MockMethod"
serverID := int64(1)
var incomingContext context.Context
streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
incomingContext = ctx
return nil, nil
}
interceptor := ServerIDInjectionStreamClientInterceptor(serverID)
ctx := metadata.NewOutgoingContext(context.Background(), metadata.New(make(map[string]string)))
_, err := interceptor(ctx, nil, nil, method, streamer)
assert.NoError(t, err)
md, ok := metadata.FromOutgoingContext(incomingContext)
assert.True(t, ok)
assert.Equal(t, fmt.Sprint(serverID), md.Get(ServerIDKey)[0])
})
t.Run("test ServerIDValidationUnaryServerInterceptor", func(t *testing.T) {
method := "MockMethod"
req := &milvuspb.InsertRequest{}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return nil, nil
}
serverInfo := &grpc.UnaryServerInfo{FullMethod: method}
interceptor := ServerIDValidationUnaryServerInterceptor()
// no md in context
_, err := interceptor(context.Background(), req, serverInfo, handler)
assert.NoError(t, err)
// no ServerID in md
ctx := metadata.NewIncomingContext(context.Background(), metadata.New(make(map[string]string)))
_, err = interceptor(ctx, req, serverInfo, handler)
assert.NoError(t, err)
// with invalid ServerID
md := metadata.Pairs(ServerIDKey, "@$#$%")
ctx = metadata.NewIncomingContext(context.Background(), md)
_, err = interceptor(ctx, req, serverInfo, handler)
assert.NoError(t, err)
// with mismatch ServerID
md = metadata.Pairs(ServerIDKey, "1234")
ctx = metadata.NewIncomingContext(context.Background(), md)
_, err = interceptor(ctx, req, serverInfo, handler)
assert.ErrorIs(t, err, merr.ErrServerIDMismatch)
// with same ServerID
md = metadata.Pairs(ServerIDKey, fmt.Sprint(paramtable.GetNodeID()))
ctx = metadata.NewIncomingContext(context.Background(), md)
_, err = interceptor(ctx, req, serverInfo, handler)
assert.NoError(t, err)
})
t.Run("test ServerIDValidationUnaryServerInterceptor", func(t *testing.T) {
handler := func(srv interface{}, stream grpc.ServerStream) error {
return nil
}
interceptor := ServerIDValidationStreamServerInterceptor()
// no md in context
err := interceptor(nil, newMockSS(context.Background()), nil, handler)
assert.NoError(t, err)
// no ServerID in md
ctx := metadata.NewIncomingContext(context.Background(), metadata.New(make(map[string]string)))
err = interceptor(nil, newMockSS(ctx), nil, handler)
assert.NoError(t, err)
// with invalid ServerID
md := metadata.Pairs(ServerIDKey, "@$#$%")
ctx = metadata.NewIncomingContext(context.Background(), md)
err = interceptor(nil, newMockSS(ctx), nil, handler)
assert.NoError(t, err)
// with mismatch ServerID
md = metadata.Pairs(ServerIDKey, "1234")
ctx = metadata.NewIncomingContext(context.Background(), md)
err = interceptor(nil, newMockSS(ctx), nil, handler)
assert.ErrorIs(t, err, merr.ErrServerIDMismatch)
// with same ServerID
md = metadata.Pairs(ServerIDKey, fmt.Sprint(paramtable.GetNodeID()))
ctx = metadata.NewIncomingContext(context.Background(), md)
err = interceptor(nil, newMockSS(ctx), nil, handler)
assert.NoError(t, err)
})
}

View File

@ -55,6 +55,7 @@ var (
ErrServiceInternal = newMilvusError("service internal error", 5, false) // Never return this error out of Milvus
ErrCrossClusterRouting = newMilvusError("cross cluster routing", 6, false)
ErrServiceDiskLimitExceeded = newMilvusError("disk limit exceeded", 7, false)
ErrServerIDMismatch = newMilvusError("server ID mismatch", 8, false)
// Collection related
ErrCollectionNotFound = newMilvusError("collection not found", 100, false)

View File

@ -77,6 +77,7 @@ func (s *ErrSuite) TestWrap() {
s.ErrorIs(WrapErrServiceInternal("never throw out"), ErrServiceInternal)
s.ErrorIs(WrapErrCrossClusterRouting("ins-0", "ins-1"), ErrCrossClusterRouting)
s.ErrorIs(WrapErrServiceDiskLimitExceeded(110, 100, "DLE"), ErrServiceDiskLimitExceeded)
s.ErrorIs(WrapErrServerIDMismatch(0, 1, "SIM"), ErrServerIDMismatch)
// Collection related
s.ErrorIs(WrapErrCollectionNotFound("test_collection", "failed to get collection"), ErrCollectionNotFound)

View File

@ -196,6 +196,14 @@ func WrapErrServiceDiskLimitExceeded(predict, limit float32, msg ...string) erro
return err
}
func WrapErrServerIDMismatch(expectedID, actualID int64, msg ...string) error {
err := errors.Wrapf(ErrServerIDMismatch, "expected=%s, actual=%s", expectedID, actualID)
if len(msg) > 0 {
err = errors.Wrap(err, strings.Join(msg, "; "))
}
return err
}
func WrapErrDatabaseNotFound(database any, msg ...string) error {
err := wrapWithField(ErrDatabaseNotfound, "database", database)
if len(msg) > 0 {

View File

@ -118,13 +118,13 @@ func (s *CrossClusterRoutingSuite) SetupTest() {
s.NoError(err)
s.queryCoordClient, err = grpcquerycoordclient.NewClient(s.ctx, metaRoot, s.client)
s.NoError(err)
s.proxyClient, err = grpcproxyclient.NewClient(s.ctx, paramtable.Get().ProxyGrpcClientCfg.GetInternalAddress())
s.proxyClient, err = grpcproxyclient.NewClient(s.ctx, paramtable.Get().ProxyGrpcClientCfg.GetInternalAddress(), 1)
s.NoError(err)
s.dataNodeClient, err = grpcdatanodeclient.NewClient(s.ctx, paramtable.Get().DataNodeGrpcClientCfg.GetAddress())
s.dataNodeClient, err = grpcdatanodeclient.NewClient(s.ctx, paramtable.Get().DataNodeGrpcClientCfg.GetAddress(), 1)
s.NoError(err)
s.queryNodeClient, err = grpcquerynodeclient.NewClient(s.ctx, paramtable.Get().QueryNodeGrpcClientCfg.GetAddress())
s.queryNodeClient, err = grpcquerynodeclient.NewClient(s.ctx, paramtable.Get().QueryNodeGrpcClientCfg.GetAddress(), 1)
s.NoError(err)
s.indexNodeClient, err = grpcindexnodeclient.NewClient(s.ctx, paramtable.Get().IndexNodeGrpcClientCfg.GetAddress(), false)
s.indexNodeClient, err = grpcindexnodeclient.NewClient(s.ctx, paramtable.Get().IndexNodeGrpcClientCfg.GetAddress(), 1, false)
s.NoError(err)
// setup servers

View File

@ -1179,7 +1179,7 @@ func (cluster *MiniCluster) UpdateClusterSize(clusterConfig ClusterConfig) error
return nil
}
func (cluster *MiniCluster) GetProxy(ctx context.Context, addr string) (types.Proxy, error) {
func (cluster *MiniCluster) GetProxy(ctx context.Context, addr string, nodeID int64) (types.Proxy, error) {
cluster.mu.RLock()
defer cluster.mu.RUnlock()
if cluster.Proxy.GetAddress() == addr {
@ -1188,7 +1188,7 @@ func (cluster *MiniCluster) GetProxy(ctx context.Context, addr string) (types.Pr
return nil, nil
}
func (cluster *MiniCluster) GetQueryNode(ctx context.Context, addr string) (types.QueryNode, error) {
func (cluster *MiniCluster) GetQueryNode(ctx context.Context, addr string, nodeID int64) (types.QueryNode, error) {
cluster.mu.RLock()
defer cluster.mu.RUnlock()
for _, queryNode := range cluster.QueryNodes {
@ -1199,7 +1199,7 @@ func (cluster *MiniCluster) GetQueryNode(ctx context.Context, addr string) (type
return nil, errors.New("no related queryNode found")
}
func (cluster *MiniCluster) GetDataNode(ctx context.Context, addr string) (types.DataNode, error) {
func (cluster *MiniCluster) GetDataNode(ctx context.Context, addr string, nodeID int64) (types.DataNode, error) {
cluster.mu.RLock()
defer cluster.mu.RUnlock()
for _, dataNode := range cluster.DataNodes {
@ -1210,7 +1210,7 @@ func (cluster *MiniCluster) GetDataNode(ctx context.Context, addr string) (types
return nil, errors.New("no related dataNode found")
}
func (cluster *MiniCluster) GetIndexNode(ctx context.Context, addr string) (types.IndexNode, error) {
func (cluster *MiniCluster) GetIndexNode(ctx context.Context, addr string, nodeID int64) (types.IndexNode, error) {
cluster.mu.RLock()
defer cluster.mu.RUnlock()
for _, indexNode := range cluster.IndexNodes {