From af22977cd3ea662c81a43a3b4fd1de4ed7c5f2ed Mon Sep 17 00:00:00 2001 From: "zhenshan.cao" Date: Fri, 3 Dec 2021 15:15:32 +0800 Subject: [PATCH] Fixbug: grpc connection is closed by mistake (#12307) Signed-off-by: zhenshan.cao --- .../distributed/datacoord/client/client.go | 394 ++++---------- .../datacoord/client/client_test.go | 134 ++--- .../distributed/datanode/client/client.go | 232 ++------- .../datanode/client/client_test.go | 68 +-- .../distributed/indexcoord/client/client.go | 285 +++------- .../indexcoord/client/client_test.go | 1 + .../distributed/indexnode/client/client.go | 220 ++------ .../indexnode/client/client_test.go | 60 +-- internal/distributed/proxy/client/client.go | 201 +------ .../distributed/proxy/client/client_test.go | 61 +-- .../distributed/querycoord/client/client.go | 327 +++--------- .../querycoord/client/client_test.go | 99 +--- .../distributed/querynode/client/client.go | 271 ++-------- .../querynode/client/client_test.go | 95 +--- .../distributed/rootcoord/client/client.go | 493 +++++------------- .../rootcoord/client/client_test.go | 147 +----- internal/querycoord/index_checker.go | 12 +- internal/querycoord/querynode.go | 22 +- internal/querycoord/task.go | 41 +- internal/querynode/data_sync_service_test.go | 6 +- .../querynode/flow_graph_query_node_test.go | 4 +- .../flow_graph_service_time_node_test.go | 2 +- internal/querynode/historical_test.go | 12 +- internal/querynode/mock_test.go | 2 +- internal/querynode/plan_test.go | 2 +- internal/querynode/query_collection_test.go | 4 +- internal/querynode/query_node.go | 4 +- internal/querynode/query_node_test.go | 2 +- internal/querynode/query_service_test.go | 2 +- internal/querynode/segment_test.go | 4 +- internal/querynode/streaming_test.go | 16 +- internal/querynode/tsafe_replica.go | 7 +- internal/querynode/tsafe_replica_test.go | 5 +- internal/util/grpcclient/client.go | 230 ++++++++ internal/util/grpcclient/client_test.go | 19 + internal/util/mock/datacoord_client.go | 119 +++++ internal/util/mock/datanode_client.go | 56 ++ internal/util/mock/grpcclient.go | 153 ++++++ internal/util/mock/indexnode_client.go | 52 ++ internal/util/mock/proxy_client.go | 52 ++ internal/util/mock/querycoord_client.go | 88 ++++ internal/util/mock/querynode_client.go | 83 +++ internal/util/mock/rootcoord_client.go | 136 +++++ 43 files changed, 1769 insertions(+), 2454 deletions(-) create mode 100644 internal/util/grpcclient/client.go create mode 100644 internal/util/grpcclient/client_test.go create mode 100644 internal/util/mock/datacoord_client.go create mode 100644 internal/util/mock/datanode_client.go create mode 100644 internal/util/mock/grpcclient.go create mode 100644 internal/util/mock/indexnode_client.go create mode 100644 internal/util/mock/proxy_client.go create mode 100644 internal/util/mock/querycoord_client.go create mode 100644 internal/util/mock/querynode_client.go create mode 100644 internal/util/mock/rootcoord_client.go diff --git a/internal/distributed/datacoord/client/client.go b/internal/distributed/datacoord/client/client.go index e205355325..8fecf271de 100644 --- a/internal/distributed/datacoord/client/client.go +++ b/internal/distributed/datacoord/client/client.go @@ -19,87 +19,57 @@ package grpcdatacoordclient import ( "context" "fmt" - "sync" - "time" - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" - grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" - "github.com/milvus-io/milvus/internal/log" - "github.com/milvus-io/milvus/internal/proto/milvuspb" - "github.com/milvus-io/milvus/internal/util/funcutil" - "github.com/milvus-io/milvus/internal/util/retry" - "github.com/milvus-io/milvus/internal/util/sessionutil" - "github.com/milvus-io/milvus/internal/util/trace" "github.com/milvus-io/milvus/internal/util/typeutil" "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/keepalive" + "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/grpcclient" + "github.com/milvus-io/milvus/internal/util/sessionutil" ) // Client is the datacoord grpc client type Client struct { - ctx context.Context - cancel context.CancelFunc - - grpcClient datapb.DataCoordClient - conn *grpc.ClientConn - grpcClientMtx sync.RWMutex - - sess *sessionutil.Session - addr string - - getGrpcClient func() (datapb.DataCoordClient, error) + grpcClient grpcclient.GrpcClient + sess *sessionutil.Session } -func (c *Client) setGetGrpcClientFunc() { - c.getGrpcClient = c.getGrpcClientFunc -} - -func (c *Client) getGrpcClientFunc() (datapb.DataCoordClient, error) { - c.grpcClientMtx.RLock() - if c.grpcClient != nil { - defer c.grpcClientMtx.RUnlock() - return c.grpcClient, nil - } - c.grpcClientMtx.RUnlock() - - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - - if c.grpcClient != nil { - return c.grpcClient, nil - } - - // FIXME(dragondriver): how to handle error here? - // if we return nil here, then we should check if client is nil outside, - err := c.connect(retry.Attempts(20)) - if err != nil { +// NewClient creates a new client instance +func NewClient(ctx context.Context, metaRoot string, etcdEndpoints []string) (*Client, error) { + sess := sessionutil.NewSession(ctx, metaRoot, etcdEndpoints) + if sess == nil { + err := fmt.Errorf("new session error, maybe can not connect to etcd") + log.Debug("DataCoordClient NewClient failed", zap.Error(err)) return nil, err } - - return c.grpcClient, nil -} - -func (c *Client) resetConnection() { - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - - if c.conn != nil { - _ = c.conn.Close() + Params.Init() + client := &Client{ + grpcClient: &grpcclient.ClientBase{ + ClientMaxRecvSize: Params.ClientMaxRecvSize, + ClientMaxSendSize: Params.ClientMaxSendSize, + }, + sess: sess, } - c.conn = nil - c.grpcClient = nil + client.grpcClient.SetRole(typeutil.DataCoordRole) + client.grpcClient.SetGetAddrFunc(client.getDataCoordAddr) + client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) + + return client, nil } -func getDataCoordAddress(sess *sessionutil.Session) (string, error) { - key := typeutil.DataCoordRole - msess, _, err := sess.GetSessions(key) +func (c *Client) newGrpcClient(cc *grpc.ClientConn) interface{} { + return datapb.NewDataCoordClient(cc) +} + +func (c *Client) getDataCoordAddr() (string, error) { + key := c.grpcClient.GetRole() + msess, _, err := c.sess.GetSessions(key) if err != nil { log.Debug("DataCoordClient, getSessions failed", zap.Any("key", key), zap.Error(err)) return "", err @@ -112,141 +82,33 @@ func getDataCoordAddress(sess *sessionutil.Session) (string, error) { return ms.Address, nil } -// NewClient creates a new client instance -func NewClient(ctx context.Context, metaRoot string, etcdEndpoints []string) (*Client, error) { - sess := sessionutil.NewSession(ctx, metaRoot, etcdEndpoints) - if sess == nil { - err := fmt.Errorf("new session error, maybe can not connect to etcd") - log.Debug("DataCoordClient NewClient failed", zap.Error(err)) - return nil, err - } - ctx, cancel := context.WithCancel(ctx) - client := &Client{ - ctx: ctx, - cancel: cancel, - sess: sess, - } - - client.setGetGrpcClientFunc() - return client, nil -} - // Init initializes the client func (c *Client) Init() error { - Params.Init() return nil } -func (c *Client) connect(retryOptions ...retry.Option) error { - var kacp = keepalive.ClientParameters{ - Time: 60 * time.Second, // send pings every 60 seconds if there is no activity - Timeout: 6 * time.Second, // wait 6 second for ping ack before considering the connection dead - PermitWithoutStream: true, // send pings even without active streams - } - - var err error - connectDataCoordFn := func() error { - c.addr, err = getDataCoordAddress(c.sess) - if err != nil { - log.Debug("DataCoordClient getDataCoordAddr failed", zap.Error(err)) - return err - } - opts := trace.GetInterceptorOpts() - log.Debug("DataCoordClient try reconnect ", zap.String("address", c.addr)) - ctx, cancel := context.WithTimeout(c.ctx, 15*time.Second) - defer cancel() - conn, err := grpc.DialContext(ctx, c.addr, - grpc.WithKeepaliveParams(kacp), - grpc.WithInsecure(), grpc.WithBlock(), - grpc.WithDefaultCallOptions( - grpc.MaxCallRecvMsgSize(Params.ClientMaxRecvSize), - grpc.MaxCallSendMsgSize(Params.ClientMaxSendSize)), - grpc.WithUnaryInterceptor( - grpc_middleware.ChainUnaryClient( - grpc_retry.UnaryClientInterceptor( - grpc_retry.WithMax(3), - grpc_retry.WithCodes(codes.Aborted, codes.Unavailable), - ), - grpc_opentracing.UnaryClientInterceptor(opts...), - )), - grpc.WithStreamInterceptor( - grpc_middleware.ChainStreamClient( - grpc_retry.StreamClientInterceptor(grpc_retry.WithMax(3), - grpc_retry.WithCodes(codes.Aborted, codes.Unavailable), - ), - grpc_opentracing.StreamClientInterceptor(opts...), - )), - ) - if err != nil { - return err - } - if c.conn != nil { - _ = c.conn.Close() - } - c.conn = conn - return nil - } - - err = retry.Do(c.ctx, connectDataCoordFn, retryOptions...) - if err != nil { - log.Debug("DataCoord try reconnect failed", zap.Error(err)) - return err - } - c.grpcClient = datapb.NewDataCoordClient(c.conn) - return nil -} - -func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error) { - ret, err := caller() - if err == nil { - return ret, nil - } - if err == context.Canceled || err == context.DeadlineExceeded { - return nil, fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - } - - log.Debug("DataCoord Client grpc error", zap.Error(err)) - - c.resetConnection() - - ret, err = caller() - if err != nil { - return nil, fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - } - - return ret, err -} - // Start enables the client func (c *Client) Start() error { return nil } func (c *Client) Stop() error { - c.cancel() - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - if c.conn != nil { - return c.conn.Close() - } - return nil + return c.grpcClient.Close() } -// Register dumy +// Register dummy func (c *Client) Register() error { return nil } func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + log.Debug("ABC", zap.Any("ctx", ctx), zap.Any("func", c.grpcClient.ReCall)) + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) + log.Debug("ABC", zap.Any("client", client)) + return client.(datapb.DataCoordClient).GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) }) if err != nil || ret == nil { return nil, err @@ -256,15 +118,11 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS // GetTimeTickChannel return the name of time tick channel. func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) + return client.(datapb.DataCoordClient).GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -274,15 +132,11 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon // GetStatisticsChannel return the name of statistics channel. func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.Call(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + return client.(datapb.DataCoordClient).GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -291,15 +145,11 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp } func (c *Client) Flush(ctx context.Context, req *datapb.FlushRequest) (*datapb.FlushResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.Flush(ctx, req) + return client.(datapb.DataCoordClient).Flush(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -321,15 +171,11 @@ func (c *Client) Flush(ctx context.Context, req *datapb.FlushRequest) (*datapb.F // if the VChannel is newly used, `WatchDmlChannels` will be invoked to notify a `DataNode`(selected by policy) to watch it // if there is anything make the allocation impossible, the response will not contain the corresponding result func (c *Client) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.AssignSegmentID(ctx, req) + return client.(datapb.DataCoordClient).AssignSegmentID(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -347,15 +193,11 @@ func (c *Client) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentI // otherwise the Segment State and Start position information will be returned // error is returned only when some communication issue occurs func (c *Client) GetSegmentStates(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetSegmentStates(ctx, req) + return client.(datapb.DataCoordClient).GetSegmentStates(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -372,15 +214,11 @@ func (c *Client) GetSegmentStates(ctx context.Context, req *datapb.GetSegmentSta // and corresponding binlog path list // error is returned only when some communication issue occurs func (c *Client) GetInsertBinlogPaths(ctx context.Context, req *datapb.GetInsertBinlogPathsRequest) (*datapb.GetInsertBinlogPathsResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetInsertBinlogPaths(ctx, req) + return client.(datapb.DataCoordClient).GetInsertBinlogPaths(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -397,15 +235,11 @@ func (c *Client) GetInsertBinlogPaths(ctx context.Context, req *datapb.GetInsert // only row count for now // error is returned only when some communication issue occurs func (c *Client) GetCollectionStatistics(ctx context.Context, req *datapb.GetCollectionStatisticsRequest) (*datapb.GetCollectionStatisticsResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetCollectionStatistics(ctx, req) + return client.(datapb.DataCoordClient).GetCollectionStatistics(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -422,15 +256,11 @@ func (c *Client) GetCollectionStatistics(ctx context.Context, req *datapb.GetCol // only row count for now // error is returned only when some communication issue occurs func (c *Client) GetPartitionStatistics(ctx context.Context, req *datapb.GetPartitionStatisticsRequest) (*datapb.GetPartitionStatisticsResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetPartitionStatistics(ctx, req) + return client.(datapb.DataCoordClient).GetPartitionStatistics(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -441,15 +271,11 @@ func (c *Client) GetPartitionStatistics(ctx context.Context, req *datapb.GetPart // GetSegmentInfoChannel DEPRECATED // legacy api to get SegmentInfo Channel name func (c *Client) GetSegmentInfoChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetSegmentInfoChannel(ctx, &datapb.GetSegmentInfoChannelRequest{}) + return client.(datapb.DataCoordClient).GetSegmentInfoChannel(ctx, &datapb.GetSegmentInfoChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -465,15 +291,11 @@ func (c *Client) GetSegmentInfoChannel(ctx context.Context) (*milvuspb.StringRes // response struct `GetSegmentInfoResponse` contains the list of segment info // error is returned only when some communication issue occurs func (c *Client) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoRequest) (*datapb.GetSegmentInfoResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetSegmentInfo(ctx, req) + return client.(datapb.DataCoordClient).GetSegmentInfo(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -494,16 +316,18 @@ func (c *Client) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoR // the root reason is each `SaveBinlogPaths` will overwrite the checkpoint position // if the constraint is broken, the checkpoint position will not be monotonically increasing and the integrity will be compromised func (c *Client) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest) (*commonpb.Status, error) { - // FIXME(dragondriver): why not to recall here? - client, err := c.getGrpcClient() - if err != nil { - return &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_UnexpectedError, - Reason: err.Error(), - }, nil + // use Call here on purpose + ret, err := c.grpcClient.Call(ctx, func(client interface{}) (interface{}, error) { + if !funcutil.CheckCtxValid(ctx) { + return nil, ctx.Err() + } + return client.(datapb.DataCoordClient).SaveBinlogPaths(ctx, req) + }) + log.Debug("abc,", zap.Any("ret", ret), zap.Error(err)) + if err != nil || ret == nil { + return nil, err } - - return client.SaveBinlogPaths(ctx, req) + return ret.(*commonpb.Status), err } // GetRecoveryInfo request segment recovery info of collection/partition @@ -514,15 +338,11 @@ func (c *Client) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPath // response struct `GetRecoveryInfoResponse` contains the list of segments info and corresponding vchannel info // error is returned only when some communication issue occurs func (c *Client) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInfoRequest) (*datapb.GetRecoveryInfoResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetRecoveryInfo(ctx, req) + return client.(datapb.DataCoordClient).GetRecoveryInfo(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -539,15 +359,11 @@ func (c *Client) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf // response struct `GetFlushedSegmentsResponse` contains flushed segment id list // error is returned only when some communication issue occurs func (c *Client) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedSegmentsRequest) (*datapb.GetFlushedSegmentsResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetFlushedSegments(ctx, req) + return client.(datapb.DataCoordClient).GetFlushedSegments(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -557,15 +373,11 @@ func (c *Client) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedS // GetMetrics gets all metrics of datacoord func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetMetrics(ctx, req) + return client.(datapb.DataCoordClient).GetMetrics(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -574,15 +386,11 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest } func (c *Client) CompleteCompaction(ctx context.Context, req *datapb.CompactionResult) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.CompleteCompaction(ctx, req) + return client.(datapb.DataCoordClient).CompleteCompaction(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -591,15 +399,11 @@ func (c *Client) CompleteCompaction(ctx context.Context, req *datapb.CompactionR } func (c *Client) ManualCompaction(ctx context.Context, req *milvuspb.ManualCompactionRequest) (*milvuspb.ManualCompactionResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.ManualCompaction(ctx, req) + return client.(datapb.DataCoordClient).ManualCompaction(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -608,15 +412,11 @@ func (c *Client) ManualCompaction(ctx context.Context, req *milvuspb.ManualCompa } func (c *Client) GetCompactionState(ctx context.Context, req *milvuspb.GetCompactionStateRequest) (*milvuspb.GetCompactionStateResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetCompactionState(ctx, req) + return client.(datapb.DataCoordClient).GetCompactionState(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -625,15 +425,11 @@ func (c *Client) GetCompactionState(ctx context.Context, req *milvuspb.GetCompac } func (c *Client) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.GetCompactionPlansRequest) (*milvuspb.GetCompactionPlansResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetCompactionStateWithPlans(ctx, req) + return client.(datapb.DataCoordClient).GetCompactionStateWithPlans(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -643,15 +439,11 @@ func (c *Client) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb. // WatchChannels notifies DataCoord to watch vchannels of a collection func (c *Client) WatchChannels(ctx context.Context, req *datapb.WatchChannelsRequest) (*datapb.WatchChannelsResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.WatchChannels(ctx, req) + return client.(datapb.DataCoordClient).WatchChannels(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -661,15 +453,11 @@ func (c *Client) WatchChannels(ctx context.Context, req *datapb.WatchChannelsReq // GetFlushState gets the flush state of multiple segments func (c *Client) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetFlushState(ctx, req) + return client.(datapb.DataCoordClient).GetFlushState(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -679,15 +467,11 @@ func (c *Client) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateR // DropVirtualChannel drops virtual channel in datacoord. func (c *Client) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest) (*datapb.DropVirtualChannelResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.DropVirtualChannel(ctx, req) + return client.(datapb.DataCoordClient).DropVirtualChannel(ctx, req) }) if err != nil || ret == nil { return nil, err diff --git a/internal/distributed/datacoord/client/client_test.go b/internal/distributed/datacoord/client/client_test.go index 7d417adc8a..00172ea9d2 100644 --- a/internal/distributed/datacoord/client/client_test.go +++ b/internal/distributed/datacoord/client/client_test.go @@ -21,106 +21,13 @@ import ( "errors" "testing" - "github.com/milvus-io/milvus/internal/proto/commonpb" - "github.com/milvus-io/milvus/internal/proto/datapb" - "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/util/mock" + "github.com/milvus-io/milvus/internal/proxy" "github.com/stretchr/testify/assert" "google.golang.org/grpc" ) -type MockDataCoordClient struct { - err error -} - -func (m *MockDataCoordClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) { - return &internalpb.ComponentStates{}, m.err -} - -func (m *MockDataCoordClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { - return &milvuspb.StringResponse{}, m.err -} - -func (m *MockDataCoordClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { - return &milvuspb.StringResponse{}, m.err -} - -func (m *MockDataCoordClient) Flush(ctx context.Context, in *datapb.FlushRequest, opts ...grpc.CallOption) (*datapb.FlushResponse, error) { - return &datapb.FlushResponse{}, m.err -} - -func (m *MockDataCoordClient) AssignSegmentID(ctx context.Context, in *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) { - return &datapb.AssignSegmentIDResponse{}, m.err -} - -func (m *MockDataCoordClient) GetSegmentInfo(ctx context.Context, in *datapb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*datapb.GetSegmentInfoResponse, error) { - return &datapb.GetSegmentInfoResponse{}, m.err -} - -func (m *MockDataCoordClient) GetSegmentStates(ctx context.Context, in *datapb.GetSegmentStatesRequest, opts ...grpc.CallOption) (*datapb.GetSegmentStatesResponse, error) { - return &datapb.GetSegmentStatesResponse{}, m.err -} - -func (m *MockDataCoordClient) GetInsertBinlogPaths(ctx context.Context, in *datapb.GetInsertBinlogPathsRequest, opts ...grpc.CallOption) (*datapb.GetInsertBinlogPathsResponse, error) { - return &datapb.GetInsertBinlogPathsResponse{}, m.err -} - -func (m *MockDataCoordClient) GetCollectionStatistics(ctx context.Context, in *datapb.GetCollectionStatisticsRequest, opts ...grpc.CallOption) (*datapb.GetCollectionStatisticsResponse, error) { - return &datapb.GetCollectionStatisticsResponse{}, m.err -} - -func (m *MockDataCoordClient) GetPartitionStatistics(ctx context.Context, in *datapb.GetPartitionStatisticsRequest, opts ...grpc.CallOption) (*datapb.GetPartitionStatisticsResponse, error) { - return &datapb.GetPartitionStatisticsResponse{}, m.err -} - -func (m *MockDataCoordClient) GetSegmentInfoChannel(ctx context.Context, in *datapb.GetSegmentInfoChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { - return &milvuspb.StringResponse{}, m.err -} - -func (m *MockDataCoordClient) SaveBinlogPaths(ctx context.Context, in *datapb.SaveBinlogPathsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockDataCoordClient) GetRecoveryInfo(ctx context.Context, in *datapb.GetRecoveryInfoRequest, opts ...grpc.CallOption) (*datapb.GetRecoveryInfoResponse, error) { - return &datapb.GetRecoveryInfoResponse{}, m.err -} - -func (m *MockDataCoordClient) GetFlushedSegments(ctx context.Context, in *datapb.GetFlushedSegmentsRequest, opts ...grpc.CallOption) (*datapb.GetFlushedSegmentsResponse, error) { - return &datapb.GetFlushedSegmentsResponse{}, m.err -} - -func (m *MockDataCoordClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { - return &milvuspb.GetMetricsResponse{}, m.err -} - -func (m *MockDataCoordClient) CompleteCompaction(ctx context.Context, req *datapb.CompactionResult, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockDataCoordClient) ManualCompaction(ctx context.Context, in *milvuspb.ManualCompactionRequest, opts ...grpc.CallOption) (*milvuspb.ManualCompactionResponse, error) { - return &milvuspb.ManualCompactionResponse{}, m.err -} - -func (m *MockDataCoordClient) GetCompactionState(ctx context.Context, in *milvuspb.GetCompactionStateRequest, opts ...grpc.CallOption) (*milvuspb.GetCompactionStateResponse, error) { - return &milvuspb.GetCompactionStateResponse{}, m.err -} - -func (m *MockDataCoordClient) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.GetCompactionPlansRequest, opts ...grpc.CallOption) (*milvuspb.GetCompactionPlansResponse, error) { - return &milvuspb.GetCompactionPlansResponse{}, m.err -} - -func (m *MockDataCoordClient) WatchChannels(ctx context.Context, req *datapb.WatchChannelsRequest, opts ...grpc.CallOption) (*datapb.WatchChannelsResponse, error) { - return &datapb.WatchChannelsResponse{}, m.err -} -func (m *MockDataCoordClient) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest, opts ...grpc.CallOption) (*milvuspb.GetFlushStateResponse, error) { - return &milvuspb.GetFlushStateResponse{}, m.err -} - -func (m *MockDataCoordClient) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest, opts ...grpc.CallOption) (*datapb.DropVirtualChannelResponse, error) { - return &datapb.DropVirtualChannelResponse{}, m.err -} - func Test_NewClient(t *testing.T) { proxy.Params.InitOnce() @@ -213,31 +120,50 @@ func Test_NewClient(t *testing.T) { retCheck(retNotNil, r21, err) } - client.getGrpcClient = func() (datapb.DataCoordClient, error) { - return &MockDataCoordClient{err: nil}, errors.New("dummy") + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: errors.New("dummy"), } + + newFunc1 := func(cc *grpc.ClientConn) interface{} { + return &mock.DataCoordClient{Err: nil} + } + client.grpcClient.SetNewGrpcClientFunc(newFunc1) + checkFunc(false) // special case since this method didn't use recall() ret, err := client.SaveBinlogPaths(ctx, nil) - assert.NotNil(t, ret) - assert.Nil(t, err) + assert.Nil(t, ret) + assert.NotNil(t, err) - client.getGrpcClient = func() (datapb.DataCoordClient, error) { - return &MockDataCoordClient{err: errors.New("dummy")}, nil + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: nil, } + newFunc2 := func(cc *grpc.ClientConn) interface{} { + return &mock.DataCoordClient{Err: errors.New("dummy")} + } + client.grpcClient.SetNewGrpcClientFunc(newFunc2) checkFunc(false) // special case since this method didn't use recall() ret, err = client.SaveBinlogPaths(ctx, nil) - assert.NotNil(t, ret) + assert.Nil(t, ret) assert.NotNil(t, err) - client.getGrpcClient = func() (datapb.DataCoordClient, error) { - return &MockDataCoordClient{err: nil}, nil + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: nil, } + newFunc3 := func(cc *grpc.ClientConn) interface{} { + return &mock.DataCoordClient{Err: nil} + } + client.grpcClient.SetNewGrpcClientFunc(newFunc3) checkFunc(true) + // special case since this method didn't use recall() + ret, err = client.SaveBinlogPaths(ctx, nil) + assert.NotNil(t, ret) + assert.Nil(t, err) + err = client.Stop() assert.Nil(t, err) } diff --git a/internal/distributed/datanode/client/client.go b/internal/distributed/datanode/client/client.go index c0d2fc1759..8d82c4b1a2 100644 --- a/internal/distributed/datanode/client/client.go +++ b/internal/distributed/datanode/client/client.go @@ -19,184 +19,49 @@ package grpcdatanodeclient import ( "context" "fmt" - "sync" - "time" - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" - grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" - "github.com/milvus-io/milvus/internal/log" + "github.com/milvus-io/milvus/internal/util/typeutil" + "google.golang.org/grpc" + "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/util/funcutil" - "github.com/milvus-io/milvus/internal/util/retry" - "github.com/milvus-io/milvus/internal/util/trace" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/keepalive" - - "go.uber.org/zap" - "google.golang.org/grpc" + "github.com/milvus-io/milvus/internal/util/grpcclient" ) // Client is the grpc client for DataNode type Client struct { - ctx context.Context - cancel context.CancelFunc - - grpc datapb.DataNodeClient - conn *grpc.ClientConn - grpcMtx sync.RWMutex - - addr string - - retryOptions []retry.Option - - getGrpcClient func() (datapb.DataNodeClient, error) -} - -func (c *Client) setGetGrpcClientFunc() { - c.getGrpcClient = c.getGrpcClientFunc -} - -func (c *Client) getGrpcClientFunc() (datapb.DataNodeClient, error) { - c.grpcMtx.RLock() - if c.grpc != nil { - defer c.grpcMtx.RUnlock() - return c.grpc, nil - } - c.grpcMtx.RUnlock() - - c.grpcMtx.Lock() - defer c.grpcMtx.Unlock() - - if c.grpc != nil { - return c.grpc, nil - } - - // FIXME(dragondriver): how to handle error here? - // if we return nil here, then we should check if client is nil outside, - err := c.connect(retry.Attempts(20)) - if err != nil { - log.Debug("DatanodeClient try reconnect failed", zap.Error(err)) - return nil, err - } - - return c.grpc, nil -} - -func (c *Client) resetConnection() { - c.grpcMtx.Lock() - defer c.grpcMtx.Unlock() - - if c.conn != nil { - _ = c.conn.Close() - } - c.conn = nil - c.grpc = nil + grpcClient grpcclient.GrpcClient + addr string } // NewClient creates a client for DataNode. -func NewClient(ctx context.Context, addr string, retryOptions ...retry.Option) (*Client, error) { +func NewClient(ctx context.Context, addr string) (*Client, error) { if addr == "" { return nil, fmt.Errorf("address is empty") } - - ctx, cancel := context.WithCancel(ctx) + Params.Init() client := &Client{ - ctx: ctx, - cancel: cancel, - addr: addr, - retryOptions: retryOptions, + addr: addr, + grpcClient: &grpcclient.ClientBase{ + ClientMaxRecvSize: Params.ClientMaxRecvSize, + ClientMaxSendSize: Params.ClientMaxSendSize, + }, } + client.grpcClient.SetRole(typeutil.DataNodeRole) + client.grpcClient.SetGetAddrFunc(client.getAddr) + client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) - client.setGetGrpcClientFunc() return client, nil } // Init initializes the client. func (c *Client) Init() error { - Params.Init() return nil } -func (c *Client) connect(retryOptions ...retry.Option) error { - var kacp = keepalive.ClientParameters{ - Time: 60 * time.Second, // send pings every 60 seconds if there is no activity - Timeout: 6 * time.Second, // wait 6 second for ping ack before considering the connection dead - PermitWithoutStream: true, // send pings even without active streams - } - - connectGrpcFunc := func() error { - opts := trace.GetInterceptorOpts() - log.Debug("DataNode connect ", zap.String("address", c.addr)) - ctx, cancel := context.WithTimeout(c.ctx, 15*time.Second) - defer cancel() - conn, err := grpc.DialContext(ctx, c.addr, - grpc.WithKeepaliveParams(kacp), - grpc.WithInsecure(), grpc.WithBlock(), - grpc.WithDefaultCallOptions( - grpc.MaxCallRecvMsgSize(Params.ClientMaxRecvSize), - grpc.MaxCallSendMsgSize(Params.ClientMaxSendSize)), - grpc.WithDisableRetry(), - grpc.WithUnaryInterceptor( - grpc_middleware.ChainUnaryClient( - grpc_retry.UnaryClientInterceptor( - grpc_retry.WithMax(3), - grpc_retry.WithCodes(codes.Aborted, codes.Unavailable), - ), - grpc_opentracing.UnaryClientInterceptor(opts...), - )), - grpc.WithStreamInterceptor( - grpc_middleware.ChainStreamClient( - grpc_retry.StreamClientInterceptor( - grpc_retry.WithMax(3), - grpc_retry.WithCodes(codes.Aborted, codes.Unavailable), - ), - grpc_opentracing.StreamClientInterceptor(opts...), - )), - ) - if err != nil { - return err - } - if c.conn != nil { - _ = c.conn.Close() - } - c.conn = conn - return nil - } - - err := retry.Do(c.ctx, connectGrpcFunc, retryOptions...) - if err != nil { - log.Debug("DataNodeClient try connect failed", zap.Error(err)) - return err - } - log.Debug("DataNodeClient connect success") - c.grpc = datapb.NewDataNodeClient(c.conn) - return nil -} - -func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error) { - ret, err := caller() - if err == nil { - return ret, nil - } - if err == context.Canceled || err == context.DeadlineExceeded { - return nil, fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - } - - log.Debug("DataNode Client grpc error", zap.Error(err)) - - c.resetConnection() - - ret, err = caller() - if err != nil { - return nil, fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - } - return ret, err -} - // Start starts the client. // Currently, it does nothing. func (c *Client) Start() error { @@ -206,13 +71,7 @@ func (c *Client) Start() error { // Stop stops the client. // Currently, it closes the grpc connection with the DataNode. func (c *Client) Stop() error { - c.cancel() - c.grpcMtx.Lock() - defer c.grpcMtx.Unlock() - if c.conn != nil { - return c.conn.Close() - } - return nil + return c.grpcClient.Close() } // Register does nothing. @@ -220,17 +79,20 @@ func (c *Client) Register() error { return nil } -// GetComponentStates returns ComponentStates +func (c *Client) newGrpcClient(cc *grpc.ClientConn) interface{} { + return datapb.NewDataNodeClient(cc) +} + +func (c *Client) getAddr() (string, error) { + return c.addr, nil +} + func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) + return client.(datapb.DataNodeClient).GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) }) if err != nil || ret == nil { return nil, err @@ -241,15 +103,11 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS // GetStatisticsChannel return the statistics channel in string // Statistics channel contains statistics infos of query nodes, such as segment infos, memory infos func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + return client.(datapb.DataNodeClient).GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -259,15 +117,11 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp // WatchDmChannels create consumers on dmChannels to reveive Incremental data func (c *Client) WatchDmChannels(ctx context.Context, req *datapb.WatchDmChannelsRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.WatchDmChannels(ctx, req) + return client.(datapb.DataNodeClient).WatchDmChannels(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -284,15 +138,11 @@ func (c *Client) WatchDmChannels(ctx context.Context, req *datapb.WatchDmChannel // Return Success code in status and trigers background flush: // Log an info log if a segment is under flushing func (c *Client) FlushSegments(ctx context.Context, req *datapb.FlushSegmentsRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.FlushSegments(ctx, req) + return client.(datapb.DataNodeClient).FlushSegments(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -301,15 +151,11 @@ func (c *Client) FlushSegments(ctx context.Context, req *datapb.FlushSegmentsReq } func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetMetrics(ctx, req) + return client.(datapb.DataNodeClient).GetMetrics(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -318,15 +164,11 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest } func (c *Client) Compaction(ctx context.Context, req *datapb.CompactionPlan) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.Compaction(ctx, req) + return client.(datapb.DataNodeClient).Compaction(ctx, req) }) if err != nil || ret == nil { return nil, err diff --git a/internal/distributed/datanode/client/client_test.go b/internal/distributed/datanode/client/client_test.go index 6b64d62530..c351a6c889 100644 --- a/internal/distributed/datanode/client/client_test.go +++ b/internal/distributed/datanode/client/client_test.go @@ -21,46 +21,15 @@ import ( "errors" "testing" - "github.com/milvus-io/milvus/internal/proto/commonpb" - "github.com/milvus-io/milvus/internal/proto/datapb" - "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/util/mock" + "google.golang.org/grpc" + "github.com/milvus-io/milvus/internal/proxy" "github.com/stretchr/testify/assert" - "google.golang.org/grpc" ) -type MockDataNodeClient struct { - err error -} - -func (m *MockDataNodeClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) { - return &internalpb.ComponentStates{}, m.err -} - -func (m *MockDataNodeClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { - return &milvuspb.StringResponse{}, m.err -} - -func (m *MockDataNodeClient) WatchDmChannels(ctx context.Context, in *datapb.WatchDmChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockDataNodeClient) FlushSegments(ctx context.Context, in *datapb.FlushSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockDataNodeClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { - return &milvuspb.GetMetricsResponse{}, m.err -} - -func (m *MockDataNodeClient) Compaction(ctx context.Context, req *datapb.CompactionPlan, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - func Test_NewClient(t *testing.T) { proxy.Params.InitOnce() - ctx := context.Background() client, err := NewClient(ctx, "") assert.Nil(t, client) @@ -109,19 +78,38 @@ func Test_NewClient(t *testing.T) { retCheck(retNotNil, r6, err) } - client.getGrpcClient = func() (datapb.DataNodeClient, error) { - return &MockDataNodeClient{err: nil}, errors.New("dummy") + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: errors.New("dummy"), } + + newFunc1 := func(cc *grpc.ClientConn) interface{} { + return &mock.DataNodeClient{Err: nil} + } + client.grpcClient.SetNewGrpcClientFunc(newFunc1) + checkFunc(false) - client.getGrpcClient = func() (datapb.DataNodeClient, error) { - return &MockDataNodeClient{err: errors.New("dummy")}, nil + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: nil, } + + newFunc2 := func(cc *grpc.ClientConn) interface{} { + return &mock.DataNodeClient{Err: errors.New("dummy")} + } + + client.grpcClient.SetNewGrpcClientFunc(newFunc2) + checkFunc(false) - client.getGrpcClient = func() (datapb.DataNodeClient, error) { - return &MockDataNodeClient{err: nil}, nil + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: nil, } + + newFunc3 := func(cc *grpc.ClientConn) interface{} { + return &mock.DataNodeClient{Err: nil} + } + client.grpcClient.SetNewGrpcClientFunc(newFunc3) + checkFunc(true) err = client.Stop() diff --git a/internal/distributed/indexcoord/client/client.go b/internal/distributed/indexcoord/client/client.go index d1d6111e30..83ef6037e6 100644 --- a/internal/distributed/indexcoord/client/client.go +++ b/internal/distributed/indexcoord/client/client.go @@ -19,81 +19,73 @@ package grpcindexcoordclient import ( "context" "fmt" - "sync" - "time" - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" - grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" - "github.com/milvus-io/milvus/internal/log" - "github.com/milvus-io/milvus/internal/util/funcutil" - "github.com/milvus-io/milvus/internal/util/retry" - "github.com/milvus-io/milvus/internal/util/sessionutil" - "github.com/milvus-io/milvus/internal/util/trace" - "github.com/milvus-io/milvus/internal/util/typeutil" "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/grpc/keepalive" + "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/grpcclient" + "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/internal/util/typeutil" ) // Client is the grpc client of IndexCoord. type Client struct { - ctx context.Context - cancel context.CancelFunc - - grpcClient indexpb.IndexCoordClient - conn *grpc.ClientConn - grpcClientMtx sync.RWMutex - - addr string - sess *sessionutil.Session + grpcClient grpcclient.GrpcClient + sess *sessionutil.Session } -func (c *Client) getGrpcClient() (indexpb.IndexCoordClient, error) { - c.grpcClientMtx.RLock() - if c.grpcClient != nil { - defer c.grpcClientMtx.RUnlock() - return c.grpcClient, nil - } - c.grpcClientMtx.RUnlock() - - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - - if c.grpcClient != nil { - return c.grpcClient, nil - } - - // FIXME(dragondriver): how to handle error here? - // if we return nil here, then we should check if client is nil outside, - err := c.connect(retry.Attempts(20)) - if err != nil { - log.Debug("IndexcoordClient try reconnect failed", zap.Error(err)) +// NewClient creates a new IndexCoord client. +func NewClient(ctx context.Context, metaRoot string, etcdEndpoints []string) (*Client, error) { + sess := sessionutil.NewSession(ctx, metaRoot, etcdEndpoints) + if sess == nil { + err := fmt.Errorf("new session error, maybe can not connect to etcd") + log.Debug("IndexCoordClient NewClient failed", zap.Error(err)) return nil, err } - - return c.grpcClient, nil -} - -func (c *Client) resetConnection() { - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - - if c.conn != nil { - _ = c.conn.Close() + Params.Init() + client := &Client{ + grpcClient: &grpcclient.ClientBase{ + ClientMaxRecvSize: Params.ClientMaxRecvSize, + ClientMaxSendSize: Params.ClientMaxSendSize, + }, + sess: sess, } - c.conn = nil - c.grpcClient = nil + client.grpcClient.SetRole(typeutil.IndexCoordRole) + client.grpcClient.SetGetAddrFunc(client.getIndexCoordAddr) + client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) + + return client, nil } -func getIndexCoordAddr(sess *sessionutil.Session) (string, error) { - key := typeutil.IndexCoordRole - msess, _, err := sess.GetSessions(key) +// Init initializes IndexCoord's grpc client. +func (c *Client) Init() error { + return nil +} + +// Start starts IndexCoord's client service. But it does nothing here. +func (c *Client) Start() error { + return nil +} + +// Stop stops IndexCoord's grpc client. +func (c *Client) Stop() error { + return c.grpcClient.Close() +} + +// Register dummy +func (c *Client) Register() error { + return nil +} + +func (c *Client) getIndexCoordAddr() (string, error) { + key := c.grpcClient.GetRole() + msess, _, err := c.sess.GetSessions(key) if err != nil { log.Debug("IndexCoordClient GetSessions failed", zap.Any("key", key), zap.Error(err)) return "", err @@ -107,134 +99,17 @@ func getIndexCoordAddr(sess *sessionutil.Session) (string, error) { return ms.Address, nil } -// NewClient creates a new IndexCoord client. -func NewClient(ctx context.Context, metaRoot string, etcdEndpoints []string) (*Client, error) { - sess := sessionutil.NewSession(ctx, metaRoot, etcdEndpoints) - if sess == nil { - err := fmt.Errorf("new session error, maybe can not connect to etcd") - log.Debug("RootCoordClient NewClient failed", zap.Error(err)) - return nil, err - } - ctx, cancel := context.WithCancel(ctx) - return &Client{ - ctx: ctx, - cancel: cancel, - sess: sess, - }, nil -} - -// Init initializes IndexCoord's grpc client. -func (c *Client) Init() error { - Params.Init() - return nil -} - -func (c *Client) connect(retryOptions ...retry.Option) error { - var err error - var kacp = keepalive.ClientParameters{ - Time: 60 * time.Second, // send pings every 60 seconds if there is no activity - Timeout: 6 * time.Second, // wait 6 second for ping ack before considering the connection dead - PermitWithoutStream: true, // send pings even without active streams - } - connectIndexCoordaddrFn := func() error { - c.addr, err = getIndexCoordAddr(c.sess) - if err != nil { - log.Debug("IndexCoordClient getIndexCoordAddress failed") - return err - } - opts := trace.GetInterceptorOpts() - log.Debug("IndexCoordClient try connect ", zap.String("address", c.addr)) - ctx, cancel := context.WithTimeout(c.ctx, 15*time.Second) - defer cancel() - conn, err := grpc.DialContext(ctx, c.addr, - grpc.WithKeepaliveParams(kacp), - grpc.WithInsecure(), grpc.WithBlock(), - grpc.WithDefaultCallOptions( - grpc.MaxCallRecvMsgSize(Params.ClientMaxRecvSize), - grpc.MaxCallSendMsgSize(Params.ClientMaxSendSize)), - grpc.WithUnaryInterceptor( - grpc_middleware.ChainUnaryClient( - grpc_retry.UnaryClientInterceptor(grpc_retry.WithMax(3)), - grpc_opentracing.UnaryClientInterceptor(opts...), - )), - grpc.WithStreamInterceptor( - grpc_middleware.ChainStreamClient( - grpc_retry.StreamClientInterceptor(grpc_retry.WithMax(3)), - grpc_opentracing.StreamClientInterceptor(opts...), - )), - ) - if err != nil { - return err - } - if c.conn != nil { - _ = c.conn.Close() - } - c.conn = conn - return nil - } - - err = retry.Do(c.ctx, connectIndexCoordaddrFn, retryOptions...) - if err != nil { - log.Debug("IndexCoordClient try connect failed", zap.Error(err)) - return err - } - log.Debug("IndexCoordClient connect success") - c.grpcClient = indexpb.NewIndexCoordClient(c.conn) - return nil -} - -func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error) { - ret, err := caller() - if err == nil { - return ret, nil - } - if err == context.Canceled || err == context.DeadlineExceeded { - return nil, fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - } - - log.Debug("IndexCoord Client grpc error", zap.Error(err)) - - c.resetConnection() - - ret, err = caller() - if err != nil { - return nil, fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - } - return ret, err -} - -// Start starts IndexCoord's client service. But it does nothing here. -func (c *Client) Start() error { - return nil -} - -// Stop stops IndexCoord's grpc client. -func (c *Client) Stop() error { - c.cancel() - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - if c.conn != nil { - return c.conn.Close() - } - return nil -} - -// Register dummy -func (c *Client) Register() error { - return nil +func (c *Client) newGrpcClient(cc *grpc.ClientConn) interface{} { + return indexpb.NewIndexCoordClient(cc) } // GetComponentStates gets the component states of IndexCoord. func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) + return client.(indexpb.IndexCoordClient).GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) }) if err != nil || ret == nil { return nil, err @@ -244,15 +119,11 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS // GetTimeTickChannel gets the time tick channel of IndexCoord. func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) + return client.(indexpb.IndexCoordClient).GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -262,15 +133,11 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon // GetStatisticsChannel gets the statistics channel of IndexCoord. func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + return client.(indexpb.IndexCoordClient).GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -280,15 +147,11 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp // BuildIndex sends the build index request to IndexCoord. func (c *Client) BuildIndex(ctx context.Context, req *indexpb.BuildIndexRequest) (*indexpb.BuildIndexResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.BuildIndex(ctx, req) + return client.(indexpb.IndexCoordClient).BuildIndex(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -298,15 +161,11 @@ func (c *Client) BuildIndex(ctx context.Context, req *indexpb.BuildIndexRequest) // DropIndex sends the drop index request to IndexCoord. func (c *Client) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.DropIndex(ctx, req) + return client.(indexpb.IndexCoordClient).DropIndex(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -316,15 +175,11 @@ func (c *Client) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) ( // GetIndexStates gets the index states from IndexCoord. func (c *Client) GetIndexStates(ctx context.Context, req *indexpb.GetIndexStatesRequest) (*indexpb.GetIndexStatesResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetIndexStates(ctx, req) + return client.(indexpb.IndexCoordClient).GetIndexStates(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -334,15 +189,11 @@ func (c *Client) GetIndexStates(ctx context.Context, req *indexpb.GetIndexStates // GetIndexFilePaths gets the index file paths from IndexCoord. func (c *Client) GetIndexFilePaths(ctx context.Context, req *indexpb.GetIndexFilePathsRequest) (*indexpb.GetIndexFilePathsResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetIndexFilePaths(ctx, req) + return client.(indexpb.IndexCoordClient).GetIndexFilePaths(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -352,15 +203,11 @@ func (c *Client) GetIndexFilePaths(ctx context.Context, req *indexpb.GetIndexFil // GetMetrics gets the metrics info of IndexCoord. func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetMetrics(ctx, req) + return client.(indexpb.IndexCoordClient).GetMetrics(ctx, req) }) if err != nil || ret == nil { return nil, err diff --git a/internal/distributed/indexcoord/client/client_test.go b/internal/distributed/indexcoord/client/client_test.go index 81ef91b783..a091a5cb2b 100644 --- a/internal/distributed/indexcoord/client/client_test.go +++ b/internal/distributed/indexcoord/client/client_test.go @@ -30,6 +30,7 @@ import ( ) func TestIndexCoordClient(t *testing.T) { + Params.Init() ctx := context.Background() server, err := grpcindexcoord.NewServer(ctx) assert.Nil(t, err) diff --git a/internal/distributed/indexnode/client/client.go b/internal/distributed/indexnode/client/client.go index 5438e8018a..4fb3b1426f 100644 --- a/internal/distributed/indexnode/client/client.go +++ b/internal/distributed/indexnode/client/client.go @@ -19,80 +19,21 @@ package grpcindexnodeclient import ( "context" "fmt" - "sync" - "time" - - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" - grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" - "github.com/milvus-io/milvus/internal/log" - "github.com/milvus-io/milvus/internal/proto/milvuspb" - "github.com/milvus-io/milvus/internal/util/funcutil" - "github.com/milvus-io/milvus/internal/util/retry" - "github.com/milvus-io/milvus/internal/util/trace" - "go.uber.org/zap" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/keepalive" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/grpcclient" + "github.com/milvus-io/milvus/internal/util/typeutil" + "google.golang.org/grpc" ) // Client is the grpc client of IndexNode. type Client struct { - ctx context.Context - cancel context.CancelFunc - - grpcClient indexpb.IndexNodeClient - conn *grpc.ClientConn - grpcClientMtx sync.RWMutex - - addr string - - getGrpcClient func() (indexpb.IndexNodeClient, error) -} - -func (c *Client) setGetGrpcClientFunc() { - c.getGrpcClient = c.getGrpcClientFunc -} - -func (c *Client) getGrpcClientFunc() (indexpb.IndexNodeClient, error) { - c.grpcClientMtx.RLock() - if c.grpcClient != nil { - defer c.grpcClientMtx.RUnlock() - return c.grpcClient, nil - } - c.grpcClientMtx.RUnlock() - - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - - if c.grpcClient != nil { - return c.grpcClient, nil - } - - // FIXME(dragondriver): how to handle error here? - // if we return nil here, then we should check if client is nil outside, - err := c.connect(retry.Attempts(20)) - if err != nil { - log.Debug("IndexNodeClient try reconnect failed", zap.Error(err)) - return nil, err - } - - return c.grpcClient, nil -} - -func (c *Client) resetConnection() { - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - - if c.conn != nil { - _ = c.conn.Close() - } - c.conn = nil - c.grpcClient = nil + grpcClient grpcclient.GrpcClient + addr string } // NewClient creates a new IndexNode client. @@ -100,98 +41,25 @@ func NewClient(ctx context.Context, addr string) (*Client, error) { if addr == "" { return nil, fmt.Errorf("address is empty") } - ctx, cancel := context.WithCancel(ctx) - + Params.Init() client := &Client{ - ctx: ctx, - cancel: cancel, - addr: addr, + addr: addr, + grpcClient: &grpcclient.ClientBase{ + ClientMaxRecvSize: Params.ClientMaxRecvSize, + ClientMaxSendSize: Params.ClientMaxSendSize, + }, } - - client.setGetGrpcClientFunc() + client.grpcClient.SetRole(typeutil.IndexNodeRole) + client.grpcClient.SetGetAddrFunc(client.getAddr) + client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) return client, nil } // Init initializes IndexNode's grpc client. func (c *Client) Init() error { - Params.Init() return nil } -func (c *Client) connect(retryOptions ...retry.Option) error { - var kacp = keepalive.ClientParameters{ - Time: 60 * time.Second, // send pings every 60 seconds if there is no activity - Timeout: 6 * time.Second, // wait 6 second for ping ack before considering the connection dead - PermitWithoutStream: true, // send pings even without active streams - } - - connectGrpcFunc := func() error { - opts := trace.GetInterceptorOpts() - log.Debug("IndexNodeClient try connect ", zap.String("address", c.addr)) - ctx, cancel := context.WithTimeout(c.ctx, 15*time.Second) - defer cancel() - conn, err := grpc.DialContext(ctx, c.addr, - grpc.WithKeepaliveParams(kacp), - grpc.WithInsecure(), grpc.WithBlock(), - grpc.WithDefaultCallOptions( - grpc.MaxCallRecvMsgSize(Params.ClientMaxRecvSize), - grpc.MaxCallSendMsgSize(Params.ClientMaxSendSize)), - grpc.WithUnaryInterceptor( - grpc_middleware.ChainUnaryClient( - grpc_retry.UnaryClientInterceptor( - grpc_retry.WithMax(3), - grpc_retry.WithCodes(codes.Aborted, codes.Unavailable), - ), - grpc_opentracing.UnaryClientInterceptor(opts...), - )), - grpc.WithStreamInterceptor( - grpc_middleware.ChainStreamClient( - grpc_retry.StreamClientInterceptor( - grpc_retry.WithMax(3), - grpc_retry.WithCodes(codes.Aborted, codes.Unavailable), - ), - grpc_opentracing.StreamClientInterceptor(opts...), - )), - ) - if err != nil { - return err - } - if c.conn != nil { - _ = c.conn.Close() - } - c.conn = conn - return nil - } - - err := retry.Do(c.ctx, connectGrpcFunc, retryOptions...) - if err != nil { - log.Debug("IndexNodeClient try connect failed", zap.Error(err)) - return err - } - log.Debug("IndexNodeClient try connect success", zap.String("address", c.addr)) - c.grpcClient = indexpb.NewIndexNodeClient(c.conn) - return nil -} - -func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error) { - ret, err := caller() - if err == nil { - return ret, nil - } - if err == context.Canceled || err == context.DeadlineExceeded { - return nil, fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - } - log.Debug("IndexNode Client grpc error", zap.Error(err)) - - c.resetConnection() - - ret, err = caller() - if err != nil { - return nil, fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - } - return ret, err -} - // Start starts IndexNode's client service. But it does nothing here. func (c *Client) Start() error { return nil @@ -199,13 +67,7 @@ func (c *Client) Start() error { // Stop stops IndexNode's grpc client. func (c *Client) Stop() error { - c.cancel() - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - if c.conn != nil { - _ = c.conn.Close() - } - return nil + return c.grpcClient.Close() } // Register dummy @@ -213,17 +75,21 @@ func (c *Client) Register() error { return nil } +func (c *Client) newGrpcClient(cc *grpc.ClientConn) interface{} { + return indexpb.NewIndexNodeClient(cc) +} + +func (c *Client) getAddr() (string, error) { + return c.addr, nil +} + // GetComponentStates gets the component states of IndexNode. func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) + return client.(indexpb.IndexNodeClient).GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) }) if err != nil || ret == nil { return nil, err @@ -233,15 +99,11 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS // GetTimeTickChannel gets the time tick channel of IndexNode. func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) + return client.(indexpb.IndexNodeClient).GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -251,15 +113,11 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon // GetStatisticsChannel gets the statistics channel of IndexNode. func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + return client.(indexpb.IndexNodeClient).GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -269,15 +127,11 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp // CreateIndex sends the build index request to IndexNode. func (c *Client) CreateIndex(ctx context.Context, req *indexpb.CreateIndexRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.CreateIndex(ctx, req) + return client.(indexpb.IndexNodeClient).CreateIndex(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -287,15 +141,11 @@ func (c *Client) CreateIndex(ctx context.Context, req *indexpb.CreateIndexReques // GetMetrics gets the metrics info of IndexNode. func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetMetrics(ctx, req) + return client.(indexpb.IndexNodeClient).GetMetrics(ctx, req) }) if err != nil || ret == nil { return nil, err diff --git a/internal/distributed/indexnode/client/client_test.go b/internal/distributed/indexnode/client/client_test.go index cc5ad6728d..b8f39cc2e6 100644 --- a/internal/distributed/indexnode/client/client_test.go +++ b/internal/distributed/indexnode/client/client_test.go @@ -21,44 +21,20 @@ import ( "errors" "testing" + "github.com/milvus-io/milvus/internal/util/mock" + "google.golang.org/grpc" + grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode" "github.com/milvus-io/milvus/internal/indexnode" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/milvuspb" - "github.com/milvus-io/milvus/internal/proxy" "github.com/stretchr/testify/assert" - "google.golang.org/grpc" ) -type MockIndexNodeClient struct { - err error -} - -func (m *MockIndexNodeClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) { - return &internalpb.ComponentStates{}, m.err -} - -func (m *MockIndexNodeClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { - return &milvuspb.StringResponse{}, m.err -} - -func (m *MockIndexNodeClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { - return &milvuspb.StringResponse{}, m.err -} - -func (m *MockIndexNodeClient) CreateIndex(ctx context.Context, in *indexpb.CreateIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockIndexNodeClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { - return &milvuspb.GetMetricsResponse{}, m.err -} - func Test_NewClient(t *testing.T) { - proxy.Params.InitOnce() - + Params.Init() ctx := context.Background() client, err := NewClient(ctx, "") assert.Nil(t, client) @@ -104,19 +80,35 @@ func Test_NewClient(t *testing.T) { retCheck(retNotNil, r5, err) } - client.getGrpcClient = func() (indexpb.IndexNodeClient, error) { - return &MockIndexNodeClient{err: nil}, errors.New("dummy") + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: errors.New("dummy"), } + + newFunc1 := func(cc *grpc.ClientConn) interface{} { + return &mock.IndexNodeClient{Err: nil} + } + client.grpcClient.SetNewGrpcClientFunc(newFunc1) + checkFunc(false) - client.getGrpcClient = func() (indexpb.IndexNodeClient, error) { - return &MockIndexNodeClient{err: errors.New("dummy")}, nil + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: nil, } + + newFunc2 := func(cc *grpc.ClientConn) interface{} { + return &mock.IndexNodeClient{Err: errors.New("dummy")} + } + client.grpcClient.SetNewGrpcClientFunc(newFunc2) checkFunc(false) - client.getGrpcClient = func() (indexpb.IndexNodeClient, error) { - return &MockIndexNodeClient{err: nil}, nil + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: nil, } + + newFunc3 := func(cc *grpc.ClientConn) interface{} { + return &mock.IndexNodeClient{Err: nil} + } + client.grpcClient.SetNewGrpcClientFunc(newFunc3) checkFunc(true) err = client.Stop() diff --git a/internal/distributed/proxy/client/client.go b/internal/distributed/proxy/client/client.go index a4ee47d9bf..b3b50d44db 100644 --- a/internal/distributed/proxy/client/client.go +++ b/internal/distributed/proxy/client/client.go @@ -19,78 +19,21 @@ package grpcproxyclient import ( "context" "fmt" - "sync" - "time" - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" - grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" - "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/util/funcutil" - "github.com/milvus-io/milvus/internal/util/retry" - "github.com/milvus-io/milvus/internal/util/trace" - "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/util/grpcclient" + "github.com/milvus-io/milvus/internal/util/typeutil" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/keepalive" ) // Client is the grpc client for Proxy type Client struct { - ctx context.Context - cancel context.CancelFunc - - grpcClient proxypb.ProxyClient - conn *grpc.ClientConn - grpcClientMtx sync.RWMutex - - addr string - - getGrpcClient func() (proxypb.ProxyClient, error) -} - -func (c *Client) setGetGrpcClientFunc() { - c.getGrpcClient = c.getGrpcClientFunc -} - -func (c *Client) getGrpcClientFunc() (proxypb.ProxyClient, error) { - c.grpcClientMtx.RLock() - if c.grpcClient != nil { - defer c.grpcClientMtx.RUnlock() - return c.grpcClient, nil - } - c.grpcClientMtx.RUnlock() - - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - - if c.grpcClient != nil { - return c.grpcClient, nil - } - - // FIXME(dragondriver): how to handle error here? - // if we return nil here, then we should check if client is nil outside, - err := c.connect(retry.Attempts(20)) - if err != nil { - return nil, err - } - - return c.grpcClient, nil -} - -func (c *Client) resetConnection() { - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - - if c.conn != nil { - _ = c.conn.Close() - } - c.conn = nil - c.grpcClient = nil + grpcClient grpcclient.GrpcClient + addr string } // NewClient creates a new client instance @@ -98,99 +41,31 @@ func NewClient(ctx context.Context, addr string) (*Client, error) { if addr == "" { return nil, fmt.Errorf("address is empty") } - ctx, cancel := context.WithCancel(ctx) - + Params.Init() client := &Client{ - ctx: ctx, - cancel: cancel, - addr: addr, + addr: addr, + grpcClient: &grpcclient.ClientBase{ + ClientMaxRecvSize: Params.ClientMaxRecvSize, + ClientMaxSendSize: Params.ClientMaxSendSize, + }, } - - client.setGetGrpcClientFunc() + client.grpcClient.SetRole(typeutil.ProxyRole) + client.grpcClient.SetGetAddrFunc(client.getAddr) + client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) return client, nil } // Init initializes Proxy's grpc client. func (c *Client) Init() error { - Params.Init() - return c.connect(retry.Attempts(20)) -} - -func (c *Client) connect(retryOptions ...retry.Option) error { - var kacp = keepalive.ClientParameters{ - Time: 60 * time.Second, // send pings every 60 seconds if there is no activity - Timeout: 6 * time.Second, // wait 6 second for ping ack before considering the connection dead - PermitWithoutStream: true, // send pings even without active streams - } - connectGrpcFunc := func() error { - opts := trace.GetInterceptorOpts() - log.Debug("ProxyClient try connect ", zap.String("address", c.addr)) - ctx, cancel := context.WithTimeout(c.ctx, 15*time.Second) - defer cancel() - conn, err := grpc.DialContext(ctx, c.addr, - grpc.WithKeepaliveParams(kacp), - grpc.WithInsecure(), grpc.WithBlock(), - grpc.WithDefaultCallOptions( - grpc.MaxCallRecvMsgSize(Params.ClientMaxRecvSize), - grpc.MaxCallSendMsgSize(Params.ClientMaxSendSize)), - grpc.WithUnaryInterceptor( - grpc_middleware.ChainUnaryClient( - grpc_retry.UnaryClientInterceptor( - grpc_retry.WithMax(3), - grpc_retry.WithCodes(codes.Aborted, codes.Unavailable), - ), - grpc_opentracing.UnaryClientInterceptor(opts...), - )), - grpc.WithStreamInterceptor( - grpc_middleware.ChainStreamClient( - grpc_retry.StreamClientInterceptor( - grpc_retry.WithMax(3), - grpc_retry.WithCodes(codes.Aborted, codes.Unavailable), - ), - grpc_opentracing.StreamClientInterceptor(opts...), - )), - ) - if err != nil { - return err - } - if c.conn != nil { - _ = c.conn.Close() - } - c.conn = conn - return nil - } - - err := retry.Do(c.ctx, connectGrpcFunc, retryOptions...) - if err != nil { - log.Debug("ProxyClient try connect failed", zap.Error(err)) - return err - } - log.Debug("ProxyClient connect success") - - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - c.grpcClient = proxypb.NewProxyClient(c.conn) - return nil } -func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error) { - ret, err := caller() - if err == nil { - return ret, nil - } - if err == context.Canceled || err == context.DeadlineExceeded { - return nil, fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - } - log.Debug("Proxy Client grpc error", zap.Error(err)) +func (c *Client) newGrpcClient(cc *grpc.ClientConn) interface{} { + return proxypb.NewProxyClient(cc) +} - c.resetConnection() - - ret, err = caller() - if err != nil { - return nil, fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - } - return ret, err +func (c *Client) getAddr() (string, error) { + return c.addr, nil } // Start dummy @@ -200,13 +75,7 @@ func (c *Client) Start() error { // Stop stops the client, closes the connection func (c *Client) Stop() error { - c.cancel() - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - if c.conn != nil { - return c.conn.Close() - } - return nil + return c.grpcClient.Close() } // Register dummy @@ -216,15 +85,11 @@ func (c *Client) Register() error { // GetComponentStates get the component state. func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) + return client.(proxypb.ProxyClient).GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) }) if err != nil || ret == nil { return nil, err @@ -233,15 +98,11 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS } func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + return client.(proxypb.ProxyClient).GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -251,15 +112,11 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp // InvalidateCollectionMetaCache invalidate collection meta cache func (c *Client) InvalidateCollectionMetaCache(ctx context.Context, req *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.InvalidateCollectionMetaCache(ctx, req) + return client.(proxypb.ProxyClient).InvalidateCollectionMetaCache(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -269,15 +126,11 @@ func (c *Client) InvalidateCollectionMetaCache(ctx context.Context, req *proxypb // ReleaseDQLMessageStream release dql message stream by request func (c *Client) ReleaseDQLMessageStream(ctx context.Context, req *proxypb.ReleaseDQLMessageStreamRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.ReleaseDQLMessageStream(ctx, req) + return client.(proxypb.ProxyClient).ReleaseDQLMessageStream(ctx, req) }) if err != nil || ret == nil { return nil, err diff --git a/internal/distributed/proxy/client/client_test.go b/internal/distributed/proxy/client/client_test.go index daf1576716..f89c4204e3 100644 --- a/internal/distributed/proxy/client/client_test.go +++ b/internal/distributed/proxy/client/client_test.go @@ -21,39 +21,13 @@ import ( "errors" "testing" - "github.com/milvus-io/milvus/internal/proto/commonpb" - "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/milvuspb" - "github.com/milvus-io/milvus/internal/proto/proxypb" + "github.com/milvus-io/milvus/internal/util/mock" + "google.golang.org/grpc" + "github.com/milvus-io/milvus/internal/proxy" "github.com/stretchr/testify/assert" - "google.golang.org/grpc" ) -type MockProxyClient struct { - err error -} - -func (m *MockProxyClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) { - return &internalpb.ComponentStates{}, m.err -} - -func (m *MockProxyClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { - return &milvuspb.StringResponse{}, m.err -} - -func (m *MockProxyClient) InvalidateCollectionMetaCache(ctx context.Context, in *proxypb.InvalidateCollMetaCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockProxyClient) GetDdChannel(ctx context.Context, in *internalpb.GetDdChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { - return &milvuspb.StringResponse{}, m.err -} - -func (m *MockProxyClient) ReleaseDQLMessageStream(ctx context.Context, in *proxypb.ReleaseDQLMessageStreamRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - func Test_NewClient(t *testing.T) { proxy.Params.InitOnce() @@ -96,19 +70,36 @@ func Test_NewClient(t *testing.T) { retCheck(retNotNil, r4, err) } - client.getGrpcClient = func() (proxypb.ProxyClient, error) { - return &MockProxyClient{err: nil}, errors.New("dummy") + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: errors.New("dummy"), } + + newFunc1 := func(cc *grpc.ClientConn) interface{} { + return &mock.ProxyClient{Err: nil} + } + client.grpcClient.SetNewGrpcClientFunc(newFunc1) + checkFunc(false) - client.getGrpcClient = func() (proxypb.ProxyClient, error) { - return &MockProxyClient{err: errors.New("dummy")}, nil + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: nil, } + + newFunc2 := func(cc *grpc.ClientConn) interface{} { + return &mock.ProxyClient{Err: errors.New("dummy")} + } + client.grpcClient.SetNewGrpcClientFunc(newFunc2) checkFunc(false) - client.getGrpcClient = func() (proxypb.ProxyClient, error) { - return &MockProxyClient{err: nil}, nil + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: nil, } + + newFunc3 := func(cc *grpc.ClientConn) interface{} { + return &mock.ProxyClient{Err: nil} + } + client.grpcClient.SetNewGrpcClientFunc(newFunc3) + checkFunc(true) err = client.Stop() diff --git a/internal/distributed/querycoord/client/client.go b/internal/distributed/querycoord/client/client.go index c0635d7cbe..2df37177d7 100644 --- a/internal/distributed/querycoord/client/client.go +++ b/internal/distributed/querycoord/client/client.go @@ -19,88 +19,57 @@ package grpcquerycoordclient import ( "context" "fmt" - "sync" - "time" - - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" - grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" - "github.com/milvus-io/milvus/internal/util/funcutil" - "github.com/milvus-io/milvus/internal/util/retry" - "github.com/milvus-io/milvus/internal/util/sessionutil" - "github.com/milvus-io/milvus/internal/util/trace" - "github.com/milvus-io/milvus/internal/util/typeutil" - "go.uber.org/zap" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/keepalive" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/grpcclient" + "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/internal/util/typeutil" + "go.uber.org/zap" + "google.golang.org/grpc" ) // Client is the grpc client of QueryCoord. type Client struct { - ctx context.Context - cancel context.CancelFunc - - grpcClient querypb.QueryCoordClient - conn *grpc.ClientConn - grpcClientMtx sync.RWMutex - - sess *sessionutil.Session - addr string - - getGrpcClient func() (querypb.QueryCoordClient, error) + grpcClient grpcclient.GrpcClient + sess *sessionutil.Session } -func (c *Client) setGetGrpcClientFunc() { - c.getGrpcClient = c.getGrpcClientFunc -} - -func (c *Client) getGrpcClientFunc() (querypb.QueryCoordClient, error) { - c.grpcClientMtx.RLock() - if c.grpcClient != nil { - defer c.grpcClientMtx.RUnlock() - return c.grpcClient, nil - } - c.grpcClientMtx.RUnlock() - - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - - if c.grpcClient != nil { - return c.grpcClient, nil - } - - // FIXME(dragondriver): how to handle error here? - // if we return nil here, then we should check if client is nil outside, - err := c.connect(retry.Attempts(20)) - if err != nil { - log.Warn("QueryCoordClient try connect fail", zap.Error(err)) +// NewClient creates a client for QueryCoord grpc call. +func NewClient(ctx context.Context, metaRoot string, etcdEndpoints []string) (*Client, error) { + sess := sessionutil.NewSession(ctx, metaRoot, etcdEndpoints) + if sess == nil { + err := fmt.Errorf("new session error, maybe can not connect to etcd") + log.Debug("QueryCoordClient NewClient failed", zap.Error(err)) return nil, err } - - return c.grpcClient, nil -} - -func (c *Client) resetConnection() { - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - - if c.conn != nil { - _ = c.conn.Close() + Params.Init() + client := &Client{ + grpcClient: &grpcclient.ClientBase{ + ClientMaxRecvSize: Params.ClientMaxRecvSize, + ClientMaxSendSize: Params.ClientMaxSendSize, + }, + sess: sess, } - c.conn = nil - c.grpcClient = nil + client.grpcClient.SetRole(typeutil.QueryCoordRole) + client.grpcClient.SetGetAddrFunc(client.getQueryCoordAddr) + client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) + + return client, nil } -func getQueryCoordAddress(sess *sessionutil.Session) (string, error) { - key := typeutil.QueryCoordRole - msess, _, err := sess.GetSessions(key) +// Init initializes QueryCoord's grpc client. +func (c *Client) Init() error { + return nil +} + +func (c *Client) getQueryCoordAddr() (string, error) { + key := c.grpcClient.GetRole() + msess, _, err := c.sess.GetSessions(key) if err != nil { log.Debug("QueryCoordClient GetSessions failed", zap.Error(err)) return "", err @@ -113,124 +82,18 @@ func getQueryCoordAddress(sess *sessionutil.Session) (string, error) { return ms.Address, nil } -// NewClient creates a client for QueryCoord grpc call. -func NewClient(ctx context.Context, metaRoot string, etcdEndpoints []string) (*Client, error) { - sess := sessionutil.NewSession(ctx, metaRoot, etcdEndpoints) - if sess == nil { - err := fmt.Errorf("new session error, maybe can not connect to etcd") - log.Debug("QueryCoordClient NewClient failed", zap.Error(err)) - return nil, err - } - ctx, cancel := context.WithCancel(ctx) - client := &Client{ - ctx: ctx, - cancel: cancel, - sess: sess, - } - - client.setGetGrpcClientFunc() - return client, nil +func (c *Client) newGrpcClient(cc *grpc.ClientConn) interface{} { + return querypb.NewQueryCoordClient(cc) } -// Init initializes QueryCoord's grpc client. -func (c *Client) Init() error { - Params.Init() - return nil -} - -func (c *Client) connect(retryOptions ...retry.Option) error { - var err error - var kacp = keepalive.ClientParameters{ - Time: 60 * time.Second, // send pings every 60 seconds if there is no activity - Timeout: 6 * time.Second, // wait 6 second for ping ack before considering the connection dead - PermitWithoutStream: true, // send pings even without active streams - } - connectQueryCoordAddressFn := func() error { - c.addr, err = getQueryCoordAddress(c.sess) - if err != nil { - log.Debug("QueryCoordClient getQueryCoordAddress failed", zap.Error(err)) - return err - } - opts := trace.GetInterceptorOpts() - log.Debug("QueryCoordClient try reconnect ", zap.String("address", c.addr)) - ctx, cancel := context.WithTimeout(c.ctx, 15*time.Second) - defer cancel() - conn, err := grpc.DialContext(ctx, c.addr, - grpc.WithKeepaliveParams(kacp), - grpc.WithInsecure(), grpc.WithBlock(), - grpc.WithDefaultCallOptions( - grpc.MaxCallRecvMsgSize(Params.ClientMaxRecvSize), - grpc.MaxCallSendMsgSize(Params.ClientMaxSendSize)), - grpc.WithUnaryInterceptor( - grpc_middleware.ChainUnaryClient( - grpc_retry.UnaryClientInterceptor( - grpc_retry.WithMax(3), - grpc_retry.WithCodes(codes.Aborted, codes.Unavailable), - ), - grpc_opentracing.UnaryClientInterceptor(opts...), - )), - grpc.WithStreamInterceptor( - grpc_middleware.ChainStreamClient( - grpc_retry.StreamClientInterceptor( - grpc_retry.WithMax(3), - grpc_retry.WithCodes(codes.Aborted, codes.Unavailable), - ), - grpc_opentracing.StreamClientInterceptor(opts...), - )), - ) - if err != nil { - return err - } - if c.conn != nil { - _ = c.conn.Close() - } - c.conn = conn - return nil - } - - err = retry.Do(c.ctx, connectQueryCoordAddressFn, retryOptions...) - if err != nil { - log.Debug("QueryCoordClient try reconnect failed", zap.Error(err)) - return err - } - log.Debug("QueryCoordClient try reconnect success") - c.grpcClient = querypb.NewQueryCoordClient(c.conn) - return nil -} - -func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error) { - ret, err := caller() - if err == nil { - return ret, nil - } - if err == context.Canceled || err == context.DeadlineExceeded { - return nil, fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - } - log.Debug("QueryCoord Client grpc error", zap.Error(err)) - - c.resetConnection() - - ret, err = caller() - if err != nil { - return nil, fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - } - return ret, err -} - -// Start starts QueryCoord's client service. But it does nothing here. +// Start starts QueryCoordinator's client service. But it does nothing here. func (c *Client) Start() error { return nil } -// Stop stops QueryCoord's grpc client server. +// Stop stops QueryCoordinator's grpc client server. func (c *Client) Stop() error { - c.cancel() - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - if c.conn != nil { - return c.conn.Close() - } - return nil + return c.grpcClient.Close() } // Register dummy @@ -240,15 +103,11 @@ func (c *Client) Register() error { // GetComponentStates gets the component states of QueryCoord. func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) + return client.(querypb.QueryCoordClient).GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) }) if err != nil || ret == nil { return nil, err @@ -258,15 +117,11 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS // GetTimeTickChannel gets the time tick channel of QueryCoord. func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) + return client.(querypb.QueryCoordClient).GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -276,15 +131,11 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon // GetStatisticsChannel gets the statistics channel of QueryCoord. func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + return client.(querypb.QueryCoordClient).GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -294,15 +145,11 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp // ShowCollections shows the collections in the QueryCoord. func (c *Client) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.ShowCollections(ctx, req) + return client.(querypb.QueryCoordClient).ShowCollections(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -312,15 +159,11 @@ func (c *Client) ShowCollections(ctx context.Context, req *querypb.ShowCollectio // LoadCollection loads the data of the specified collections in the QueryCoord. func (c *Client) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.LoadCollection(ctx, req) + return client.(querypb.QueryCoordClient).LoadCollection(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -330,15 +173,11 @@ func (c *Client) LoadCollection(ctx context.Context, req *querypb.LoadCollection // ReleaseCollection release the data of the specified collections in the QueryCoord. func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.ReleaseCollection(ctx, req) + return client.(querypb.QueryCoordClient).ReleaseCollection(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -348,15 +187,11 @@ func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl // ShowPartitions shows the partitions in the QueryCoord. func (c *Client) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.ShowPartitions(ctx, req) + return client.(querypb.QueryCoordClient).ShowPartitions(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -366,15 +201,11 @@ func (c *Client) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions // LoadPartitions loads the data of the specified partitions in the QueryCoord. func (c *Client) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.LoadPartitions(ctx, req) + return client.(querypb.QueryCoordClient).LoadPartitions(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -384,15 +215,11 @@ func (c *Client) LoadPartitions(ctx context.Context, req *querypb.LoadPartitions // ReleasePartitions release the data of the specified partitions in the QueryCoord. func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.ReleasePartitions(ctx, req) + return client.(querypb.QueryCoordClient).ReleasePartitions(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -402,15 +229,11 @@ func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart // CreateQueryChannel creates the channels for querying in QueryCoord. func (c *Client) CreateQueryChannel(ctx context.Context, req *querypb.CreateQueryChannelRequest) (*querypb.CreateQueryChannelResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.CreateQueryChannel(ctx, req) + return client.(querypb.QueryCoordClient).CreateQueryChannel(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -420,15 +243,11 @@ func (c *Client) CreateQueryChannel(ctx context.Context, req *querypb.CreateQuer // GetPartitionStates gets the states of the specified partition. func (c *Client) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetPartitionStates(ctx, req) + return client.(querypb.QueryCoordClient).GetPartitionStates(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -438,15 +257,11 @@ func (c *Client) GetPartitionStates(ctx context.Context, req *querypb.GetPartiti // GetSegmentInfo gets the information of the specified segment from QueryCoord. func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetSegmentInfo(ctx, req) + return client.(querypb.QueryCoordClient).GetSegmentInfo(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -456,15 +271,11 @@ func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo // LoadBalance migrate the sealed segments on the source node to the dst nodes. func (c *Client) LoadBalance(ctx context.Context, req *querypb.LoadBalanceRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.LoadBalance(ctx, req) + return client.(querypb.QueryCoordClient).LoadBalance(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -474,15 +285,11 @@ func (c *Client) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques // GetMetrics gets the metrics information of QueryCoord. func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetMetrics(ctx, req) + return client.(querypb.QueryCoordClient).GetMetrics(ctx, req) }) if err != nil || ret == nil { return nil, err diff --git a/internal/distributed/querycoord/client/client_test.go b/internal/distributed/querycoord/client/client_test.go index abee7d4e38..c138ef6c47 100644 --- a/internal/distributed/querycoord/client/client_test.go +++ b/internal/distributed/querycoord/client/client_test.go @@ -21,75 +21,13 @@ import ( "errors" "testing" - "github.com/milvus-io/milvus/internal/proto/commonpb" - "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/milvuspb" - "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/mock" + "google.golang.org/grpc" + "github.com/milvus-io/milvus/internal/proxy" "github.com/stretchr/testify/assert" - "google.golang.org/grpc" ) -type MockQueryCoordClient struct { - err error -} - -func (m *MockQueryCoordClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) { - return &internalpb.ComponentStates{}, m.err -} - -func (m *MockQueryCoordClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { - return &milvuspb.StringResponse{}, m.err -} - -func (m *MockQueryCoordClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { - return &milvuspb.StringResponse{}, m.err -} - -func (m *MockQueryCoordClient) ShowCollections(ctx context.Context, in *querypb.ShowCollectionsRequest, opts ...grpc.CallOption) (*querypb.ShowCollectionsResponse, error) { - return &querypb.ShowCollectionsResponse{}, m.err -} - -func (m *MockQueryCoordClient) ShowPartitions(ctx context.Context, in *querypb.ShowPartitionsRequest, opts ...grpc.CallOption) (*querypb.ShowPartitionsResponse, error) { - return &querypb.ShowPartitionsResponse{}, m.err -} - -func (m *MockQueryCoordClient) LoadPartitions(ctx context.Context, in *querypb.LoadPartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockQueryCoordClient) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockQueryCoordClient) LoadCollection(ctx context.Context, in *querypb.LoadCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockQueryCoordClient) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockQueryCoordClient) CreateQueryChannel(ctx context.Context, in *querypb.CreateQueryChannelRequest, opts ...grpc.CallOption) (*querypb.CreateQueryChannelResponse, error) { - return &querypb.CreateQueryChannelResponse{}, m.err -} - -func (m *MockQueryCoordClient) GetPartitionStates(ctx context.Context, in *querypb.GetPartitionStatesRequest, opts ...grpc.CallOption) (*querypb.GetPartitionStatesResponse, error) { - return &querypb.GetPartitionStatesResponse{}, m.err -} - -func (m *MockQueryCoordClient) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*querypb.GetSegmentInfoResponse, error) { - return &querypb.GetSegmentInfoResponse{}, m.err -} - -func (m *MockQueryCoordClient) LoadBalance(ctx context.Context, in *querypb.LoadBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockQueryCoordClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { - return &milvuspb.GetMetricsResponse{}, m.err -} - func Test_NewClient(t *testing.T) { proxy.Params.InitOnce() @@ -167,19 +105,38 @@ func Test_NewClient(t *testing.T) { retCheck(retNotNil, r16, err) } - client.getGrpcClient = func() (querypb.QueryCoordClient, error) { - return &MockQueryCoordClient{err: nil}, errors.New("dummy") + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: errors.New("dummy"), } + + newFunc1 := func(cc *grpc.ClientConn) interface{} { + return &mock.QueryCoordClient{Err: nil} + } + client.grpcClient.SetNewGrpcClientFunc(newFunc1) + checkFunc(false) - client.getGrpcClient = func() (querypb.QueryCoordClient, error) { - return &MockQueryCoordClient{err: errors.New("dummy")}, nil + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: nil, } + + newFunc2 := func(cc *grpc.ClientConn) interface{} { + return &mock.QueryCoordClient{Err: errors.New("dummy")} + } + + client.grpcClient.SetNewGrpcClientFunc(newFunc2) + checkFunc(false) - client.getGrpcClient = func() (querypb.QueryCoordClient, error) { - return &MockQueryCoordClient{err: nil}, nil + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: nil, } + + newFunc3 := func(cc *grpc.ClientConn) interface{} { + return &mock.QueryCoordClient{Err: nil} + } + client.grpcClient.SetNewGrpcClientFunc(newFunc3) + checkFunc(true) err = client.Stop() diff --git a/internal/distributed/querynode/client/client.go b/internal/distributed/querynode/client/client.go index 59df81c998..0a2b3f54ef 100644 --- a/internal/distributed/querynode/client/client.go +++ b/internal/distributed/querynode/client/client.go @@ -19,76 +19,21 @@ package grpcquerynodeclient import ( "context" "fmt" - "sync" - "time" - "go.uber.org/zap" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/keepalive" - - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" - grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" - "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/util/funcutil" - "github.com/milvus-io/milvus/internal/util/retry" - "github.com/milvus-io/milvus/internal/util/trace" + "github.com/milvus-io/milvus/internal/util/grpcclient" + "github.com/milvus-io/milvus/internal/util/typeutil" + "google.golang.org/grpc" ) // Client is the grpc client of QueryNode. type Client struct { - ctx context.Context - cancel context.CancelFunc - - grpcClient querypb.QueryNodeClient - conn *grpc.ClientConn - grpcClientMtx sync.RWMutex - - addr string - - getGrpcClient func() (querypb.QueryNodeClient, error) -} - -func (c *Client) setGetGrpcClientFunc() { - c.getGrpcClient = c.getGrpcClientFunc -} - -func (c *Client) getGrpcClientFunc() (querypb.QueryNodeClient, error) { - c.grpcClientMtx.RLock() - if c.grpcClient != nil { - defer c.grpcClientMtx.RUnlock() - return c.grpcClient, nil - } - c.grpcClientMtx.RUnlock() - - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - - if c.grpcClient != nil { - return c.grpcClient, nil - } - - // FIXME(dragondriver): how to handle error here? - // if we return nil here, then we should check if client is nil outside, - err := c.connect(retry.Attempts(20)) - if err != nil { - return nil, err - } - - return c.grpcClient, nil -} - -func (c *Client) resetConnection() { - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - - c.conn = nil - c.grpcClient = nil + grpcClient grpcclient.GrpcClient + addr string } // NewClient creates a new QueryNode client. @@ -96,94 +41,26 @@ func NewClient(ctx context.Context, addr string) (*Client, error) { if addr == "" { return nil, fmt.Errorf("addr is empty") } - ctx, cancel := context.WithCancel(ctx) - + Params.Init() client := &Client{ - ctx: ctx, - cancel: cancel, - addr: addr, + addr: addr, + grpcClient: &grpcclient.ClientBase{ + ClientMaxRecvSize: Params.ClientMaxRecvSize, + ClientMaxSendSize: Params.ClientMaxSendSize, + }, } + client.grpcClient.SetRole(typeutil.QueryNodeRole) + client.grpcClient.SetGetAddrFunc(client.getAddr) + client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) - client.setGetGrpcClientFunc() return client, nil } // Init initializes QueryNode's grpc client. func (c *Client) Init() error { - Params.Init() - _, err := c.getGrpcClient() - return err -} - -func (c *Client) connect(retryOptions ...retry.Option) error { - var kacp = keepalive.ClientParameters{ - Time: 60 * time.Second, // send pings every 60 seconds if there is no activity - Timeout: 6 * time.Second, // wait 6 second for ping ack before considering the connection dead - PermitWithoutStream: true, // send pings even without active streams - } - connectGrpcFunc := func() error { - opts := trace.GetInterceptorOpts() - log.Debug("QueryNodeClient try connect ", zap.String("address", c.addr)) - ctx, cancel := context.WithTimeout(c.ctx, 15*time.Second) - defer cancel() - conn, err := grpc.DialContext(ctx, c.addr, - grpc.WithKeepaliveParams(kacp), - grpc.WithInsecure(), grpc.WithBlock(), - grpc.WithDefaultCallOptions( - grpc.MaxCallRecvMsgSize(Params.ClientMaxRecvSize), - grpc.MaxCallSendMsgSize(Params.ClientMaxSendSize)), - grpc.WithUnaryInterceptor( - grpc_middleware.ChainUnaryClient( - grpc_retry.UnaryClientInterceptor( - grpc_retry.WithMax(3), - grpc_retry.WithCodes(codes.Aborted, codes.Unavailable), - ), - grpc_opentracing.UnaryClientInterceptor(opts...), - )), - grpc.WithStreamInterceptor( - grpc_middleware.ChainStreamClient( - grpc_retry.StreamClientInterceptor(grpc_retry.WithMax(3), - grpc_retry.WithCodes(codes.Aborted, codes.Unavailable), - ), - grpc_opentracing.StreamClientInterceptor(opts...), - )), - ) - if err != nil { - return err - } - c.conn = conn - return nil - } - - err := retry.Do(c.ctx, connectGrpcFunc, retryOptions...) - if err != nil { - log.Debug("QueryNodeClient try connect failed", zap.Error(err)) - return err - } - log.Debug("QueryNodeClient try connect success") - c.grpcClient = querypb.NewQueryNodeClient(c.conn) return nil } -func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error) { - ret, err := caller() - if err == nil { - return ret, nil - } - if err == context.Canceled || err == context.DeadlineExceeded { - return nil, fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - } - log.Debug("QueryNode Client grpc error", zap.Error(err)) - - c.resetConnection() - - ret, err = caller() - if err != nil { - return nil, fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - } - return ret, err -} - // Start starts QueryNode's client service. But it does nothing here. func (c *Client) Start() error { return nil @@ -191,13 +68,7 @@ func (c *Client) Start() error { // Stop stops QueryNode's grpc client server. func (c *Client) Stop() error { - c.cancel() - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - if c.conn != nil { - return c.conn.Close() - } - return nil + return c.grpcClient.Close() } // Register dummy @@ -205,17 +76,21 @@ func (c *Client) Register() error { return nil } +func (c *Client) newGrpcClient(cc *grpc.ClientConn) interface{} { + return querypb.NewQueryNodeClient(cc) +} + +func (c *Client) getAddr() (string, error) { + return c.addr, nil +} + // GetComponentStates gets the component states of QueryNode. func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) + return client.(querypb.QueryNodeClient).GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) }) if err != nil || ret == nil { return nil, err @@ -225,15 +100,11 @@ func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentS // GetTimeTickChannel gets the time tick channel of QueryNode. func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) + return client.(querypb.QueryNodeClient).GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -243,15 +114,11 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon // GetStatisticsChannel gets the statistics channel of QueryNode. func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + return client.(querypb.QueryNodeClient).GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -261,15 +128,11 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp // AddQueryChannel adds query channel for QueryNode component. func (c *Client) AddQueryChannel(ctx context.Context, req *querypb.AddQueryChannelRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.AddQueryChannel(ctx, req) + return client.(querypb.QueryNodeClient).AddQueryChannel(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -279,15 +142,11 @@ func (c *Client) AddQueryChannel(ctx context.Context, req *querypb.AddQueryChann // RemoveQueryChannel removes the query channel for QueryNode component. func (c *Client) RemoveQueryChannel(ctx context.Context, req *querypb.RemoveQueryChannelRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.RemoveQueryChannel(ctx, req) + return client.(querypb.QueryNodeClient).RemoveQueryChannel(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -297,15 +156,11 @@ func (c *Client) RemoveQueryChannel(ctx context.Context, req *querypb.RemoveQuer // WatchDmChannels watches the channels about data manipulation. func (c *Client) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.WatchDmChannels(ctx, req) + return client.(querypb.QueryNodeClient).WatchDmChannels(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -315,15 +170,11 @@ func (c *Client) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChanne // WatchDeltaChannels watches the channels about data manipulation. func (c *Client) WatchDeltaChannels(ctx context.Context, req *querypb.WatchDeltaChannelsRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.WatchDeltaChannels(ctx, req) + return client.(querypb.QueryNodeClient).WatchDeltaChannels(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -333,15 +184,11 @@ func (c *Client) WatchDeltaChannels(ctx context.Context, req *querypb.WatchDelta // LoadSegments loads the segments to search. func (c *Client) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.LoadSegments(ctx, req) + return client.(querypb.QueryNodeClient).LoadSegments(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -351,15 +198,11 @@ func (c *Client) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequ // ReleaseCollection releases the data of the specified collection in QueryNode. func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.ReleaseCollection(ctx, req) + return client.(querypb.QueryNodeClient).ReleaseCollection(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -369,15 +212,11 @@ func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl // ReleasePartitions releases the data of the specified partitions in QueryNode. func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.ReleasePartitions(ctx, req) + return client.(querypb.QueryNodeClient).ReleasePartitions(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -387,15 +226,11 @@ func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart // ReleaseSegments releases the data of the specified segments in QueryNode. func (c *Client) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.ReleaseSegments(ctx, req) + return client.(querypb.QueryNodeClient).ReleaseSegments(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -405,15 +240,11 @@ func (c *Client) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmen // GetSegmentInfo gets the information of the specified segments in QueryNode. func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetSegmentInfo(ctx, req) + return client.(querypb.QueryNodeClient).GetSegmentInfo(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -423,15 +254,11 @@ func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo // GetMetrics gets the metrics information of QueryNode. func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetMetrics(ctx, req) + return client.(querypb.QueryNodeClient).GetMetrics(ctx, req) }) if err != nil || ret == nil { return nil, err diff --git a/internal/distributed/querynode/client/client_test.go b/internal/distributed/querynode/client/client_test.go index d2585b6555..15e0033630 100644 --- a/internal/distributed/querynode/client/client_test.go +++ b/internal/distributed/querynode/client/client_test.go @@ -21,70 +21,13 @@ import ( "errors" "testing" - "github.com/milvus-io/milvus/internal/proto/commonpb" - "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/milvuspb" - "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/mock" + "google.golang.org/grpc" + "github.com/milvus-io/milvus/internal/proxy" "github.com/stretchr/testify/assert" - "google.golang.org/grpc" ) -type MockQueryNodeClient struct { - err error -} - -func (m *MockQueryNodeClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) { - return &internalpb.ComponentStates{}, m.err -} - -func (m *MockQueryNodeClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { - return &milvuspb.StringResponse{}, m.err -} -func (m *MockQueryNodeClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { - return &milvuspb.StringResponse{}, m.err -} - -func (m *MockQueryNodeClient) AddQueryChannel(ctx context.Context, in *querypb.AddQueryChannelRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockQueryNodeClient) RemoveQueryChannel(ctx context.Context, in *querypb.RemoveQueryChannelRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockQueryNodeClient) WatchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockQueryNodeClient) WatchDeltaChannels(ctx context.Context, in *querypb.WatchDeltaChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockQueryNodeClient) LoadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockQueryNodeClient) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockQueryNodeClient) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockQueryNodeClient) ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockQueryNodeClient) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*querypb.GetSegmentInfoResponse, error) { - return &querypb.GetSegmentInfoResponse{}, m.err -} - -func (m *MockQueryNodeClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { - return &milvuspb.GetMetricsResponse{}, m.err -} - func Test_NewClient(t *testing.T) { proxy.Params.InitOnce() @@ -156,21 +99,39 @@ func Test_NewClient(t *testing.T) { retCheck(retNotNil, r13, err) } - client.getGrpcClient = func() (querypb.QueryNodeClient, error) { - return &MockQueryNodeClient{err: nil}, errors.New("dummy") + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: errors.New("dummy"), } + + newFunc1 := func(cc *grpc.ClientConn) interface{} { + return &mock.QueryNodeClient{Err: nil} + } + client.grpcClient.SetNewGrpcClientFunc(newFunc1) + checkFunc(false) - client.getGrpcClient = func() (querypb.QueryNodeClient, error) { - return &MockQueryNodeClient{err: errors.New("dummy")}, nil + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: nil, } + + newFunc2 := func(cc *grpc.ClientConn) interface{} { + return &mock.QueryNodeClient{Err: errors.New("dummy")} + } + + client.grpcClient.SetNewGrpcClientFunc(newFunc2) + checkFunc(false) - client.getGrpcClient = func() (querypb.QueryNodeClient, error) { - return &MockQueryNodeClient{err: nil}, nil + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: nil, } + + newFunc3 := func(cc *grpc.ClientConn) interface{} { + return &mock.QueryNodeClient{Err: nil} + } + client.grpcClient.SetNewGrpcClientFunc(newFunc3) + checkFunc(true) - err = client.Stop() assert.Nil(t, err) } diff --git a/internal/distributed/rootcoord/client/client.go b/internal/distributed/rootcoord/client/client.go index a93c0e9b2f..c873159e6a 100644 --- a/internal/distributed/rootcoord/client/client.go +++ b/internal/distributed/rootcoord/client/client.go @@ -19,12 +19,7 @@ package grpcrootcoordclient import ( "context" "fmt" - "sync" - "time" - grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" - grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" - grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/datapb" @@ -33,34 +28,58 @@ import ( "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/util/funcutil" - "github.com/milvus-io/milvus/internal/util/retry" + "github.com/milvus-io/milvus/internal/util/grpcclient" "github.com/milvus-io/milvus/internal/util/sessionutil" - "github.com/milvus-io/milvus/internal/util/trace" "github.com/milvus-io/milvus/internal/util/typeutil" "go.uber.org/zap" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/keepalive" ) -// GrpcClient grpc client -type GrpcClient struct { - ctx context.Context - cancel context.CancelFunc - - grpcClient rootcoordpb.RootCoordClient - conn *grpc.ClientConn - grpcClientMtx sync.RWMutex - - sess *sessionutil.Session - addr string - - getGrpcClient func() (rootcoordpb.RootCoordClient, error) +// Client grpc client +type Client struct { + grpcClient grpcclient.GrpcClient + sess *sessionutil.Session } -func getRootCoordAddr(sess *sessionutil.Session) (string, error) { - key := typeutil.RootCoordRole - msess, _, err := sess.GetSessions(key) +// NewClient create root coordinator client with specified ectd info and timeout +// ctx execution control context +// metaRoot is the path in etcd for root coordinator registration +// etcdEndpoints are the address list for etcd end points +// timeout is default setting for each grpc call +func NewClient(ctx context.Context, metaRoot string, etcdEndpoints []string) (*Client, error) { + sess := sessionutil.NewSession(ctx, metaRoot, etcdEndpoints) + if sess == nil { + err := fmt.Errorf("new session error, maybe can not connect to etcd") + log.Debug("QueryCoordClient NewClient failed", zap.Error(err)) + return nil, err + } + Params.Init() + client := &Client{ + grpcClient: &grpcclient.ClientBase{ + ClientMaxRecvSize: Params.ClientMaxRecvSize, + ClientMaxSendSize: Params.ClientMaxSendSize, + }, + sess: sess, + } + client.grpcClient.SetRole(typeutil.RootCoordRole) + client.grpcClient.SetGetAddrFunc(client.getRootCoordAddr) + client.grpcClient.SetNewGrpcClientFunc(client.newGrpcClient) + + return client, nil +} + +// Init initialize grpc parameters +func (c *Client) Init() error { + return nil +} + +func (c *Client) newGrpcClient(cc *grpc.ClientConn) interface{} { + return rootcoordpb.NewRootCoordClient(cc) +} + +func (c *Client) getRootCoordAddr() (string, error) { + key := c.grpcClient.GetRole() + msess, _, err := c.sess.GetSessions(key) if err != nil { log.Debug("RootCoordClient GetSessions failed", zap.Any("key", key)) return "", err @@ -74,188 +93,28 @@ func getRootCoordAddr(sess *sessionutil.Session) (string, error) { return ms.Address, nil } -// NewClient create root coordinator client with specified ectd info and timeout -// ctx execution control context -// metaRoot is the path in etcd for root coordinator registration -// etcdEndpoints are the address list for etcd end points -// timeout is default setting for each grpc call -func NewClient(ctx context.Context, metaRoot string, etcdEndpoints []string) (*GrpcClient, error) { - sess := sessionutil.NewSession(ctx, metaRoot, etcdEndpoints) - if sess == nil { - err := fmt.Errorf("new session error, maybe can not connect to etcd") - log.Debug("RootCoordClient NewClient failed", zap.Error(err)) - return nil, err - } - ctx, cancel := context.WithCancel(ctx) - - client := &GrpcClient{ - ctx: ctx, - cancel: cancel, - sess: sess, - } - - client.setGetGrpcClientFunc() - return client, nil -} - -// Init initialize grpc parameters -func (c *GrpcClient) Init() error { - Params.Init() - return nil -} - -func (c *GrpcClient) connect(retryOptions ...retry.Option) error { - var err error - var kacp = keepalive.ClientParameters{ - Time: 60 * time.Second, // send pings every 60 seconds if there is no activity - Timeout: 6 * time.Second, // wait 6 second for ping ack before considering the connection dead - PermitWithoutStream: true, // send pings even without active streams - } - connectRootCoordAddrFn := func() error { - c.addr, err = getRootCoordAddr(c.sess) - if err != nil { - log.Debug("RootCoordClient getRootCoordAddr failed", zap.Error(err)) - return err - } - opts := trace.GetInterceptorOpts() - log.Debug("RootCoordClient try reconnect ", zap.String("address", c.addr)) - ctx, cancel := context.WithTimeout(c.ctx, 15*time.Second) - defer cancel() - conn, err := grpc.DialContext(ctx, c.addr, - grpc.WithKeepaliveParams(kacp), - grpc.WithInsecure(), grpc.WithBlock(), - grpc.WithDefaultCallOptions( - grpc.MaxCallRecvMsgSize(Params.ClientMaxRecvSize), - grpc.MaxCallSendMsgSize(Params.ClientMaxSendSize)), - grpc.WithUnaryInterceptor( - grpc_middleware.ChainUnaryClient( - grpc_retry.UnaryClientInterceptor( - grpc_retry.WithMax(3), - grpc_retry.WithCodes(codes.Aborted, codes.Unavailable), - ), - grpc_opentracing.UnaryClientInterceptor(opts...), - )), - grpc.WithStreamInterceptor( - grpc_middleware.ChainStreamClient( - grpc_retry.StreamClientInterceptor(grpc_retry.WithMax(3), - grpc_retry.WithCodes(codes.Aborted, codes.Unavailable), - ), - grpc_opentracing.StreamClientInterceptor(opts...), - )), - ) - if err != nil { - return err - } - if c.conn != nil { - _ = c.conn.Close() - } - c.conn = conn - return nil - } - - err = retry.Do(c.ctx, connectRootCoordAddrFn, retryOptions...) - if err != nil { - log.Debug("RootCoordClient try reconnect failed", zap.Error(err)) - return err - } - log.Debug("RootCoordClient try reconnect success") - - c.grpcClient = rootcoordpb.NewRootCoordClient(c.conn) - - return nil -} - -func (c *GrpcClient) setGetGrpcClientFunc() { - c.getGrpcClient = c.getGrpcClientFunc -} - -func (c *GrpcClient) getGrpcClientFunc() (rootcoordpb.RootCoordClient, error) { - c.grpcClientMtx.RLock() - if c.grpcClient != nil { - defer c.grpcClientMtx.RUnlock() - return c.grpcClient, nil - } - c.grpcClientMtx.RUnlock() - - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - - if c.grpcClient != nil { - return c.grpcClient, nil - } - - // FIXME(dragondriver): how to handle error here? - // if we return nil here, then we should check if client is nil outside, - err := c.connect(retry.Attempts(20)) - if err != nil { - log.Debug("RoodCoordClient try connect failed", zap.Error(err)) - return nil, err - } - - return c.grpcClient, nil -} - // Start dummy -func (c *GrpcClient) Start() error { +func (c *Client) Start() error { return nil } // Stop terminate grpc connection -func (c *GrpcClient) Stop() error { - c.cancel() - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - if c.conn != nil { - return c.conn.Close() - } - return nil +func (c *Client) Stop() error { + return c.grpcClient.Close() } // Register dummy -func (c *GrpcClient) Register() error { +func (c *Client) Register() error { return nil } -func (c *GrpcClient) resetConnection() { - c.grpcClientMtx.Lock() - defer c.grpcClientMtx.Unlock() - if c.conn != nil { - _ = c.conn.Close() - } - c.conn = nil - c.grpcClient = nil -} - -func (c *GrpcClient) recall(caller func() (interface{}, error)) (interface{}, error) { - ret, err := caller() - if err == nil { - return ret, nil - } - if err == context.Canceled || err == context.DeadlineExceeded { - return nil, fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - } - log.Debug("RootCoord Client grpc error", zap.Error(err)) - - c.resetConnection() - - ret, err = caller() - if err != nil { - return nil, fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - } - return ret, err -} - // GetComponentStates TODO: timeout need to be propagated through ctx -func (c *GrpcClient) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) + return client.(rootcoordpb.RootCoordClient).GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) }) if err != nil || ret == nil { return nil, err @@ -264,16 +123,12 @@ func (c *GrpcClient) GetComponentStates(ctx context.Context) (*internalpb.Compon } // GetTimeTickChannel get timetick channel name -func (c *GrpcClient) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) + return client.(rootcoordpb.RootCoordClient).GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -282,16 +137,12 @@ func (c *GrpcClient) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRe } // GetStatisticsChannel just define a channel, not used currently -func (c *GrpcClient) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + return client.(rootcoordpb.RootCoordClient).GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -300,16 +151,12 @@ func (c *GrpcClient) GetStatisticsChannel(ctx context.Context) (*milvuspb.String } // CreateCollection create collection -func (c *GrpcClient) CreateCollection(ctx context.Context, in *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) CreateCollection(ctx context.Context, in *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.CreateCollection(ctx, in) + return client.(rootcoordpb.RootCoordClient).CreateCollection(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -318,16 +165,12 @@ func (c *GrpcClient) CreateCollection(ctx context.Context, in *milvuspb.CreateCo } // DropCollection drop collection -func (c *GrpcClient) DropCollection(ctx context.Context, in *milvuspb.DropCollectionRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) DropCollection(ctx context.Context, in *milvuspb.DropCollectionRequest) (*commonpb.Status, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.DropCollection(ctx, in) + return client.(rootcoordpb.RootCoordClient).DropCollection(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -336,16 +179,12 @@ func (c *GrpcClient) DropCollection(ctx context.Context, in *milvuspb.DropCollec } // HasCollection check collection existence -func (c *GrpcClient) HasCollection(ctx context.Context, in *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) HasCollection(ctx context.Context, in *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.HasCollection(ctx, in) + return client.(rootcoordpb.RootCoordClient).HasCollection(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -354,16 +193,12 @@ func (c *GrpcClient) HasCollection(ctx context.Context, in *milvuspb.HasCollecti } // DescribeCollection return collection info -func (c *GrpcClient) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.DescribeCollection(ctx, in) + return client.(rootcoordpb.RootCoordClient).DescribeCollection(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -372,16 +207,12 @@ func (c *GrpcClient) DescribeCollection(ctx context.Context, in *milvuspb.Descri } // ShowCollections list all collection names -func (c *GrpcClient) ShowCollections(ctx context.Context, in *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) ShowCollections(ctx context.Context, in *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.ShowCollections(ctx, in) + return client.(rootcoordpb.RootCoordClient).ShowCollections(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -390,16 +221,12 @@ func (c *GrpcClient) ShowCollections(ctx context.Context, in *milvuspb.ShowColle } // CreatePartition create partition -func (c *GrpcClient) CreatePartition(ctx context.Context, in *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) CreatePartition(ctx context.Context, in *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.CreatePartition(ctx, in) + return client.(rootcoordpb.RootCoordClient).CreatePartition(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -408,16 +235,12 @@ func (c *GrpcClient) CreatePartition(ctx context.Context, in *milvuspb.CreatePar } // DropPartition drop partition -func (c *GrpcClient) DropPartition(ctx context.Context, in *milvuspb.DropPartitionRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) DropPartition(ctx context.Context, in *milvuspb.DropPartitionRequest) (*commonpb.Status, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.DropPartition(ctx, in) + return client.(rootcoordpb.RootCoordClient).DropPartition(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -426,16 +249,12 @@ func (c *GrpcClient) DropPartition(ctx context.Context, in *milvuspb.DropPartiti } // HasPartition check partition existence -func (c *GrpcClient) HasPartition(ctx context.Context, in *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) HasPartition(ctx context.Context, in *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.HasPartition(ctx, in) + return client.(rootcoordpb.RootCoordClient).HasPartition(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -444,16 +263,12 @@ func (c *GrpcClient) HasPartition(ctx context.Context, in *milvuspb.HasPartition } // ShowPartitions list all partitions in collection -func (c *GrpcClient) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.ShowPartitions(ctx, in) + return client.(rootcoordpb.RootCoordClient).ShowPartitions(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -462,16 +277,12 @@ func (c *GrpcClient) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartit } // CreateIndex create index -func (c *GrpcClient) CreateIndex(ctx context.Context, in *milvuspb.CreateIndexRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) CreateIndex(ctx context.Context, in *milvuspb.CreateIndexRequest) (*commonpb.Status, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.CreateIndex(ctx, in) + return client.(rootcoordpb.RootCoordClient).CreateIndex(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -480,16 +291,12 @@ func (c *GrpcClient) CreateIndex(ctx context.Context, in *milvuspb.CreateIndexRe } // DropIndex drop index -func (c *GrpcClient) DropIndex(ctx context.Context, in *milvuspb.DropIndexRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) DropIndex(ctx context.Context, in *milvuspb.DropIndexRequest) (*commonpb.Status, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.DropIndex(ctx, in) + return client.(rootcoordpb.RootCoordClient).DropIndex(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -498,16 +305,12 @@ func (c *GrpcClient) DropIndex(ctx context.Context, in *milvuspb.DropIndexReques } // DescribeIndex return index info -func (c *GrpcClient) DescribeIndex(ctx context.Context, in *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) DescribeIndex(ctx context.Context, in *milvuspb.DescribeIndexRequest) (*milvuspb.DescribeIndexResponse, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.DescribeIndex(ctx, in) + return client.(rootcoordpb.RootCoordClient).DescribeIndex(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -516,16 +319,12 @@ func (c *GrpcClient) DescribeIndex(ctx context.Context, in *milvuspb.DescribeInd } // AllocTimestamp global timestamp allocator -func (c *GrpcClient) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.AllocTimestamp(ctx, in) + return client.(rootcoordpb.RootCoordClient).AllocTimestamp(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -534,16 +333,12 @@ func (c *GrpcClient) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTi } // AllocID global ID allocator -func (c *GrpcClient) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.AllocID(ctx, in) + return client.(rootcoordpb.RootCoordClient).AllocID(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -552,16 +347,12 @@ func (c *GrpcClient) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest } // UpdateChannelTimeTick used to handle ChannelTimeTickMsg -func (c *GrpcClient) UpdateChannelTimeTick(ctx context.Context, in *internalpb.ChannelTimeTickMsg) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) UpdateChannelTimeTick(ctx context.Context, in *internalpb.ChannelTimeTickMsg) (*commonpb.Status, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.UpdateChannelTimeTick(ctx, in) + return client.(rootcoordpb.RootCoordClient).UpdateChannelTimeTick(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -570,16 +361,12 @@ func (c *GrpcClient) UpdateChannelTimeTick(ctx context.Context, in *internalpb.C } // DescribeSegment receiver time tick from proxy service, and put it into this channel -func (c *GrpcClient) DescribeSegment(ctx context.Context, in *milvuspb.DescribeSegmentRequest) (*milvuspb.DescribeSegmentResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) DescribeSegment(ctx context.Context, in *milvuspb.DescribeSegmentRequest) (*milvuspb.DescribeSegmentResponse, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.DescribeSegment(ctx, in) + return client.(rootcoordpb.RootCoordClient).DescribeSegment(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -588,16 +375,12 @@ func (c *GrpcClient) DescribeSegment(ctx context.Context, in *milvuspb.DescribeS } // ShowSegments list all segments -func (c *GrpcClient) ShowSegments(ctx context.Context, in *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) ShowSegments(ctx context.Context, in *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.ShowSegments(ctx, in) + return client.(rootcoordpb.RootCoordClient).ShowSegments(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -606,16 +389,12 @@ func (c *GrpcClient) ShowSegments(ctx context.Context, in *milvuspb.ShowSegments } // ReleaseDQLMessageStream release DQL msgstream -func (c *GrpcClient) ReleaseDQLMessageStream(ctx context.Context, in *proxypb.ReleaseDQLMessageStreamRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) ReleaseDQLMessageStream(ctx context.Context, in *proxypb.ReleaseDQLMessageStreamRequest) (*commonpb.Status, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.ReleaseDQLMessageStream(ctx, in) + return client.(rootcoordpb.RootCoordClient).ReleaseDQLMessageStream(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -624,16 +403,12 @@ func (c *GrpcClient) ReleaseDQLMessageStream(ctx context.Context, in *proxypb.Re } // SegmentFlushCompleted check whether segment flush is completed -func (c *GrpcClient) SegmentFlushCompleted(ctx context.Context, in *datapb.SegmentFlushCompletedMsg) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) SegmentFlushCompleted(ctx context.Context, in *datapb.SegmentFlushCompletedMsg) (*commonpb.Status, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.SegmentFlushCompleted(ctx, in) + return client.(rootcoordpb.RootCoordClient).SegmentFlushCompleted(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -642,16 +417,12 @@ func (c *GrpcClient) SegmentFlushCompleted(ctx context.Context, in *datapb.Segme } // GetMetrics get metrics -func (c *GrpcClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.GetMetrics(ctx, in) + return client.(rootcoordpb.RootCoordClient).GetMetrics(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -660,16 +431,12 @@ func (c *GrpcClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequ } // CreateAlias create collection alias -func (c *GrpcClient) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.CreateAlias(ctx, req) + return client.(rootcoordpb.RootCoordClient).CreateAlias(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -678,16 +445,12 @@ func (c *GrpcClient) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasR } // DropAlias drop collection alias -func (c *GrpcClient) DropAlias(ctx context.Context, req *milvuspb.DropAliasRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) DropAlias(ctx context.Context, req *milvuspb.DropAliasRequest) (*commonpb.Status, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.DropAlias(ctx, req) + return client.(rootcoordpb.RootCoordClient).DropAlias(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -696,16 +459,12 @@ func (c *GrpcClient) DropAlias(ctx context.Context, req *milvuspb.DropAliasReque } // AlterAlias alter collection alias -func (c *GrpcClient) AlterAlias(ctx context.Context, req *milvuspb.AlterAliasRequest) (*commonpb.Status, error) { - ret, err := c.recall(func() (interface{}, error) { - client, err := c.getGrpcClient() - if err != nil { - return nil, err - } +func (c *Client) AlterAlias(ctx context.Context, req *milvuspb.AlterAliasRequest) (*commonpb.Status, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.AlterAlias(ctx, req) + return client.(rootcoordpb.RootCoordClient).AlterAlias(ctx, req) }) if err != nil || ret == nil { return nil, err diff --git a/internal/distributed/rootcoord/client/client_test.go b/internal/distributed/rootcoord/client/client_test.go index 2a66d560a9..3c87e1d645 100644 --- a/internal/distributed/rootcoord/client/client_test.go +++ b/internal/distributed/rootcoord/client/client_test.go @@ -21,123 +21,13 @@ import ( "errors" "testing" - "github.com/milvus-io/milvus/internal/proto/commonpb" - "github.com/milvus-io/milvus/internal/proto/datapb" - "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/proto/milvuspb" - "github.com/milvus-io/milvus/internal/proto/proxypb" - "github.com/milvus-io/milvus/internal/proto/rootcoordpb" + "github.com/milvus-io/milvus/internal/util/mock" + "google.golang.org/grpc" + "github.com/milvus-io/milvus/internal/proxy" "github.com/stretchr/testify/assert" - "google.golang.org/grpc" ) -type MockRootCoordClient struct { - err error -} - -func (m *MockRootCoordClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) { - return &internalpb.ComponentStates{}, m.err -} -func (m *MockRootCoordClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { - return &milvuspb.StringResponse{}, m.err -} -func (m *MockRootCoordClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { - return &milvuspb.StringResponse{}, m.err -} - -func (m *MockRootCoordClient) CreateCollection(ctx context.Context, in *milvuspb.CreateCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockRootCoordClient) DropCollection(ctx context.Context, in *milvuspb.DropCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockRootCoordClient) HasCollection(ctx context.Context, in *milvuspb.HasCollectionRequest, opts ...grpc.CallOption) (*milvuspb.BoolResponse, error) { - return &milvuspb.BoolResponse{}, m.err -} - -func (m *MockRootCoordClient) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { - return &milvuspb.DescribeCollectionResponse{}, m.err -} - -func (m *MockRootCoordClient) CreateAlias(ctx context.Context, in *milvuspb.CreateAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockRootCoordClient) DropAlias(ctx context.Context, in *milvuspb.DropAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockRootCoordClient) AlterAlias(ctx context.Context, in *milvuspb.AlterAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockRootCoordClient) ShowCollections(ctx context.Context, in *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error) { - return &milvuspb.ShowCollectionsResponse{}, m.err -} - -func (m *MockRootCoordClient) CreatePartition(ctx context.Context, in *milvuspb.CreatePartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockRootCoordClient) DropPartition(ctx context.Context, in *milvuspb.DropPartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockRootCoordClient) HasPartition(ctx context.Context, in *milvuspb.HasPartitionRequest, opts ...grpc.CallOption) (*milvuspb.BoolResponse, error) { - return &milvuspb.BoolResponse{}, m.err -} - -func (m *MockRootCoordClient) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) { - return &milvuspb.ShowPartitionsResponse{}, m.err -} - -func (m *MockRootCoordClient) DescribeSegment(ctx context.Context, in *milvuspb.DescribeSegmentRequest, opts ...grpc.CallOption) (*milvuspb.DescribeSegmentResponse, error) { - return &milvuspb.DescribeSegmentResponse{}, m.err -} - -func (m *MockRootCoordClient) ShowSegments(ctx context.Context, in *milvuspb.ShowSegmentsRequest, opts ...grpc.CallOption) (*milvuspb.ShowSegmentsResponse, error) { - return &milvuspb.ShowSegmentsResponse{}, m.err -} - -func (m *MockRootCoordClient) CreateIndex(ctx context.Context, in *milvuspb.CreateIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockRootCoordClient) DescribeIndex(ctx context.Context, in *milvuspb.DescribeIndexRequest, opts ...grpc.CallOption) (*milvuspb.DescribeIndexResponse, error) { - return &milvuspb.DescribeIndexResponse{}, m.err -} - -func (m *MockRootCoordClient) DropIndex(ctx context.Context, in *milvuspb.DropIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockRootCoordClient) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimestampRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) { - return &rootcoordpb.AllocTimestampResponse{}, m.err -} - -func (m *MockRootCoordClient) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocIDResponse, error) { - return &rootcoordpb.AllocIDResponse{}, m.err -} - -func (m *MockRootCoordClient) UpdateChannelTimeTick(ctx context.Context, in *internalpb.ChannelTimeTickMsg, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockRootCoordClient) ReleaseDQLMessageStream(ctx context.Context, in *proxypb.ReleaseDQLMessageStreamRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockRootCoordClient) SegmentFlushCompleted(ctx context.Context, in *datapb.SegmentFlushCompletedMsg, opts ...grpc.CallOption) (*commonpb.Status, error) { - return &commonpb.Status{}, m.err -} - -func (m *MockRootCoordClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { - return &milvuspb.GetMetricsResponse{}, m.err -} - func Test_NewClient(t *testing.T) { proxy.Params.InitOnce() @@ -245,19 +135,38 @@ func Test_NewClient(t *testing.T) { retCheck(retNotNil, r26, err) } - client.getGrpcClient = func() (rootcoordpb.RootCoordClient, error) { - return &MockRootCoordClient{err: nil}, errors.New("dummy") + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: errors.New("dummy"), } + + newFunc1 := func(cc *grpc.ClientConn) interface{} { + return &mock.RootCoordClient{Err: nil} + } + client.grpcClient.SetNewGrpcClientFunc(newFunc1) + checkFunc(false) - client.getGrpcClient = func() (rootcoordpb.RootCoordClient, error) { - return &MockRootCoordClient{err: errors.New("dummy")}, nil + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: nil, } + + newFunc2 := func(cc *grpc.ClientConn) interface{} { + return &mock.RootCoordClient{Err: errors.New("dummy")} + } + + client.grpcClient.SetNewGrpcClientFunc(newFunc2) + checkFunc(false) - client.getGrpcClient = func() (rootcoordpb.RootCoordClient, error) { - return &MockRootCoordClient{err: nil}, nil + client.grpcClient = &mock.ClientBase{ + GetGrpcClientErr: nil, } + + newFunc3 := func(cc *grpc.ClientConn) interface{} { + return &mock.RootCoordClient{Err: nil} + } + client.grpcClient.SetNewGrpcClientFunc(newFunc3) + checkFunc(true) err = client.Stop() diff --git a/internal/querycoord/index_checker.go b/internal/querycoord/index_checker.go index 9d6a386834..560ba75017 100644 --- a/internal/querycoord/index_checker.go +++ b/internal/querycoord/index_checker.go @@ -291,7 +291,9 @@ func getIndexInfo(ctx context.Context, info *querypb.SegmentInfo, root types.Roo CollectionID: info.CollectionID, SegmentID: info.SegmentID, } - response, err := root.DescribeSegment(ctx, req) + ctx2, cancel2 := context.WithTimeout(ctx, timeoutForRPC) + defer cancel2() + response, err := root.DescribeSegment(ctx2, req) if err != nil { return nil, err } @@ -309,9 +311,11 @@ func getIndexInfo(ctx context.Context, info *querypb.SegmentInfo, root types.Roo indexFilePathRequest := &indexpb.GetIndexFilePathsRequest{ IndexBuildIDs: []UniqueID{response.BuildID}, } - pathResponse, err := index.GetIndexFilePaths(ctx, indexFilePathRequest) - if err != nil { - return nil, err + ctx3, cancel3 := context.WithTimeout(ctx, timeoutForRPC) + defer cancel3() + pathResponse, err2 := index.GetIndexFilePaths(ctx3, indexFilePathRequest) + if err2 != nil { + return nil, err2 } if pathResponse.Status.ErrorCode != commonpb.ErrorCode_Success { diff --git a/internal/querycoord/querynode.go b/internal/querycoord/querynode.go index 063b059f71..517c27cfe6 100644 --- a/internal/querycoord/querynode.go +++ b/internal/querycoord/querynode.go @@ -430,7 +430,7 @@ func (qn *queryNode) watchDmChannels(ctx context.Context, in *querypb.WatchDmCha return errors.New("WatchDmChannels: queryNode is offline") } - status, err := qn.client.WatchDmChannels(ctx, in) + status, err := qn.client.WatchDmChannels(qn.ctx, in) if err != nil { return err } @@ -454,7 +454,7 @@ func (qn *queryNode) watchDeltaChannels(ctx context.Context, in *querypb.WatchDe return errors.New("WatchDmChannels: queryNode is offline") } - status, err := qn.client.WatchDeltaChannels(ctx, in) + status, err := qn.client.WatchDeltaChannels(qn.ctx, in) if err != nil { return err } @@ -470,7 +470,7 @@ func (qn *queryNode) addQueryChannel(ctx context.Context, in *querypb.AddQueryCh return errors.New("AddQueryChannel: queryNode is offline") } - status, err := qn.client.AddQueryChannel(ctx, in) + status, err := qn.client.AddQueryChannel(qn.ctx, in) if err != nil { return err } @@ -492,7 +492,7 @@ func (qn *queryNode) removeQueryChannel(ctx context.Context, in *querypb.RemoveQ return nil } - status, err := qn.client.RemoveQueryChannel(ctx, in) + status, err := qn.client.RemoveQueryChannel(qn.ctx, in) if err != nil { return err } @@ -510,7 +510,7 @@ func (qn *queryNode) releaseCollection(ctx context.Context, in *querypb.ReleaseC return nil } - status, err := qn.client.ReleaseCollection(ctx, in) + status, err := qn.client.ReleaseCollection(qn.ctx, in) if err != nil { return err } @@ -531,7 +531,7 @@ func (qn *queryNode) releasePartitions(ctx context.Context, in *querypb.ReleaseP return nil } - status, err := qn.client.ReleasePartitions(ctx, in) + status, err := qn.client.ReleasePartitions(qn.ctx, in) if err != nil { return err } @@ -551,7 +551,7 @@ func (qn *queryNode) getSegmentInfo(ctx context.Context, in *querypb.GetSegmentI return nil, fmt.Errorf("getSegmentInfo: queryNode %d is offline", qn.id) } - res, err := qn.client.GetSegmentInfo(ctx, in) + res, err := qn.client.GetSegmentInfo(qn.ctx, in) if err != nil { return nil, err } @@ -570,7 +570,7 @@ func (qn *queryNode) getComponentInfo(ctx context.Context) *internalpb.Component } } - res, err := qn.client.GetComponentStates(ctx) + res, err := qn.client.GetComponentStates(qn.ctx) if err != nil || res.Status.ErrorCode != commonpb.ErrorCode_Success { return &internalpb.ComponentInfo{ NodeID: qn.id, @@ -586,7 +586,7 @@ func (qn *queryNode) getMetrics(ctx context.Context, in *milvuspb.GetMetricsRequ return nil, errQueryNodeIsNotOnService(qn.id) } - return qn.client.GetMetrics(ctx, in) + return qn.client.GetMetrics(qn.ctx, in) } func (qn *queryNode) loadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest) error { @@ -594,7 +594,7 @@ func (qn *queryNode) loadSegments(ctx context.Context, in *querypb.LoadSegmentsR return errors.New("LoadSegments: queryNode is offline") } - status, err := qn.client.LoadSegments(ctx, in) + status, err := qn.client.LoadSegments(qn.ctx, in) if err != nil { return err } @@ -620,7 +620,7 @@ func (qn *queryNode) releaseSegments(ctx context.Context, in *querypb.ReleaseSeg return errors.New("ReleaseSegments: queryNode is offline") } - status, err := qn.client.ReleaseSegments(ctx, in) + status, err := qn.client.ReleaseSegments(qn.ctx, in) if err != nil { return err } diff --git a/internal/querycoord/task.go b/internal/querycoord/task.go index efda1a10eb..65a73ce703 100644 --- a/internal/querycoord/task.go +++ b/internal/querycoord/task.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" "sync" + "time" "github.com/golang/protobuf/proto" "go.uber.org/zap" @@ -38,6 +39,8 @@ import ( "github.com/opentracing/opentracing-go" ) +const timeoutForRPC = 10 * time.Second + const ( triggerTaskPrefix = "queryCoord-triggerTask" activeTaskPrefix = "queryCoord-activeTask" @@ -313,7 +316,9 @@ func (lct *loadCollectionTask) execute(ctx context.Context) error { }, CollectionID: collectionID, } - showPartitionResponse, err := lct.rootCoord.ShowPartitions(ctx, showPartitionRequest) + ctx2, cancel2 := context.WithTimeout(ctx, timeoutForRPC) + defer cancel2() + showPartitionResponse, err := lct.rootCoord.ShowPartitions(ctx2, showPartitionRequest) if err != nil { lct.setResultInfo(err) return err @@ -364,7 +369,11 @@ func (lct *loadCollectionTask) execute(ctx context.Context) error { CollectionID: collectionID, PartitionID: partitionID, } - recoveryInfo, err := lct.dataCoord.GetRecoveryInfo(ctx, getRecoveryInfoRequest) + recoveryInfo, err := func() (*datapb.GetRecoveryInfoResponse, error) { + ctx2, cancel2 := context.WithTimeout(ctx, timeoutForRPC) + defer cancel2() + return lct.dataCoord.GetRecoveryInfo(ctx2, getRecoveryInfoRequest) + }() if err != nil { lct.setResultInfo(err) return err @@ -586,7 +595,9 @@ func (rct *releaseCollectionTask) execute(ctx context.Context) error { DbID: rct.DbID, CollectionID: rct.CollectionID, } - res, err := rct.rootCoord.ReleaseDQLMessageStream(rct.ctx, releaseDQLMessageStreamReq) + ctx2, cancel2 := context.WithTimeout(rct.ctx, timeoutForRPC) + defer cancel2() + res, err := rct.rootCoord.ReleaseDQLMessageStream(ctx2, releaseDQLMessageStreamReq) if res.ErrorCode != commonpb.ErrorCode_Success || err != nil { log.Warn("releaseCollectionTask: release collection end, releaseDQLMessageStream occur error", zap.Int64("collectionID", rct.CollectionID)) err = errors.New("rootCoord releaseDQLMessageStream failed") @@ -732,7 +743,11 @@ func (lpt *loadPartitionTask) execute(ctx context.Context) error { CollectionID: collectionID, PartitionID: partitionID, } - recoveryInfo, err := lpt.dataCoord.GetRecoveryInfo(ctx, getRecoveryInfoRequest) + recoveryInfo, err := func() (*datapb.GetRecoveryInfoResponse, error) { + ctx2, cancel2 := context.WithTimeout(ctx, timeoutForRPC) + defer cancel2() + return lpt.dataCoord.GetRecoveryInfo(ctx2, getRecoveryInfoRequest) + }() if err != nil { lpt.setResultInfo(err) return err @@ -1542,7 +1557,11 @@ func (ht *handoffTask) execute(ctx context.Context) error { CollectionID: collectionID, PartitionID: partitionID, } - recoveryInfo, err := ht.dataCoord.GetRecoveryInfo(ctx, getRecoveryInfoRequest) + recoveryInfo, err := func() (*datapb.GetRecoveryInfoResponse, error) { + ctx2, cancel2 := context.WithTimeout(ctx, timeoutForRPC) + defer cancel2() + return ht.dataCoord.GetRecoveryInfo(ctx2, getRecoveryInfoRequest) + }() if err != nil { ht.setResultInfo(err) return err @@ -1726,7 +1745,11 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error { CollectionID: collectionID, PartitionID: partitionID, } - recoveryInfo, err := lbt.dataCoord.GetRecoveryInfo(ctx, getRecoveryInfo) + recoveryInfo, err := func() (*datapb.GetRecoveryInfoResponse, error) { + ctx2, cancel2 := context.WithTimeout(ctx, timeoutForRPC) + defer cancel2() + return lbt.dataCoord.GetRecoveryInfo(ctx2, getRecoveryInfo) + }() if err != nil { lbt.setResultInfo(err) return err @@ -1921,7 +1944,11 @@ func (lbt *loadBalanceTask) execute(ctx context.Context) error { CollectionID: collectionID, PartitionID: partitionID, } - recoveryInfo, err := lbt.dataCoord.GetRecoveryInfo(ctx, getRecoveryInfoRequest) + recoveryInfo, err := func() (*datapb.GetRecoveryInfoResponse, error) { + ctx2, cancel2 := context.WithTimeout(ctx, timeoutForRPC) + defer cancel2() + return lbt.dataCoord.GetRecoveryInfo(ctx2, getRecoveryInfoRequest) + }() if err != nil { lbt.setResultInfo(err) return err diff --git a/internal/querynode/data_sync_service_test.go b/internal/querynode/data_sync_service_test.go index 3d0a3d3257..c2e4dabddc 100644 --- a/internal/querynode/data_sync_service_test.go +++ b/internal/querynode/data_sync_service_test.go @@ -146,7 +146,7 @@ func TestDataSyncService_collectionFlowGraphs(t *testing.T) { fac, err := genFactory() assert.NoError(t, err) - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) dataSyncService := newDataSyncService(ctx, streaming, historicalReplica, tSafe, fac) assert.NotNil(t, dataSyncService) @@ -193,7 +193,7 @@ func TestDataSyncService_partitionFlowGraphs(t *testing.T) { fac, err := genFactory() assert.NoError(t, err) - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) dataSyncService := newDataSyncService(ctx, streaming, historicalReplica, tSafe, fac) assert.NotNil(t, dataSyncService) @@ -242,7 +242,7 @@ func TestDataSyncService_removePartitionFlowGraphs(t *testing.T) { fac, err := genFactory() assert.NoError(t, err) - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) tSafe.addTSafe(defaultVChannel) dataSyncService := newDataSyncService(ctx, streaming, historicalReplica, tSafe, fac) diff --git a/internal/querynode/flow_graph_query_node_test.go b/internal/querynode/flow_graph_query_node_test.go index f45235f5a5..484482777e 100644 --- a/internal/querynode/flow_graph_query_node_test.go +++ b/internal/querynode/flow_graph_query_node_test.go @@ -29,7 +29,7 @@ func TestQueryNodeFlowGraph_consumerFlowGraph(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) streamingReplica, err := genSimpleReplica() assert.NoError(t, err) @@ -62,7 +62,7 @@ func TestQueryNodeFlowGraph_seekQueryNodeFlowGraph(t *testing.T) { fac, err := genFactory() assert.NoError(t, err) - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) fg := newQueryNodeFlowGraph(ctx, loadTypeCollection, diff --git a/internal/querynode/flow_graph_service_time_node_test.go b/internal/querynode/flow_graph_service_time_node_test.go index 1df0ff167b..681fa4af5f 100644 --- a/internal/querynode/flow_graph_service_time_node_test.go +++ b/internal/querynode/flow_graph_service_time_node_test.go @@ -30,7 +30,7 @@ func TestServiceTimeNode_Operate(t *testing.T) { defer cancel() genServiceTimeNode := func() *serviceTimeNode { - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) tSafe.addTSafe(defaultVChannel) fac, err := genFactory() diff --git a/internal/querynode/historical_test.go b/internal/querynode/historical_test.go index d730f8c753..8241c6fd69 100644 --- a/internal/querynode/historical_test.go +++ b/internal/querynode/historical_test.go @@ -100,7 +100,7 @@ func TestHistorical_Search(t *testing.T) { defer cancel() t.Run("test search", func(t *testing.T) { - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) his, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) @@ -112,7 +112,7 @@ func TestHistorical_Search(t *testing.T) { }) t.Run("test no collection - search partitions", func(t *testing.T) { - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) his, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) @@ -127,7 +127,7 @@ func TestHistorical_Search(t *testing.T) { }) t.Run("test no collection - search all collection", func(t *testing.T) { - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) his, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) @@ -142,7 +142,7 @@ func TestHistorical_Search(t *testing.T) { }) t.Run("test load partition and partition has been released", func(t *testing.T) { - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) his, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) @@ -161,7 +161,7 @@ func TestHistorical_Search(t *testing.T) { }) t.Run("test no partition in collection", func(t *testing.T) { - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) his, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) @@ -178,7 +178,7 @@ func TestHistorical_Search(t *testing.T) { }) t.Run("test load collection partition released in collection", func(t *testing.T) { - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) his, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) diff --git a/internal/querynode/mock_test.go b/internal/querynode/mock_test.go index 05f0c972a2..541168932c 100644 --- a/internal/querynode/mock_test.go +++ b/internal/querynode/mock_test.go @@ -1315,7 +1315,7 @@ func genSimpleQueryNode(ctx context.Context) (*QueryNode, error) { node.etcdKV = etcdKV - node.tSafeReplica = newTSafeReplica() + node.tSafeReplica = newTSafeReplica(ctx) streaming, err := genSimpleStreaming(ctx, node.tSafeReplica) if err != nil { diff --git a/internal/querynode/plan_test.go b/internal/querynode/plan_test.go index 06da3982b6..3cdef5a56e 100644 --- a/internal/querynode/plan_test.go +++ b/internal/querynode/plan_test.go @@ -47,7 +47,7 @@ func TestPlan_createSearchPlanByExpr(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) historical, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) diff --git a/internal/querynode/query_collection_test.go b/internal/querynode/query_collection_test.go index 202fa48b77..ae54696be4 100644 --- a/internal/querynode/query_collection_test.go +++ b/internal/querynode/query_collection_test.go @@ -40,7 +40,7 @@ import ( ) func genSimpleQueryCollection(ctx context.Context, cancel context.CancelFunc) (*queryCollection, error) { - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) historical, err := genSimpleHistorical(ctx, tSafe) if err != nil { return nil, err @@ -134,7 +134,7 @@ func TestQueryCollection_withoutVChannel(t *testing.T) { schema := genTestCollectionSchema(0, false, 2) historicalReplica := newCollectionReplica(etcdKV) - tsReplica := newTSafeReplica() + tsReplica := newTSafeReplica(ctx) streamingReplica := newCollectionReplica(etcdKV) historical := newHistorical(context.Background(), historicalReplica, etcdKV, tsReplica) diff --git a/internal/querynode/query_node.go b/internal/querynode/query_node.go index 47892121ad..a4c293bbff 100644 --- a/internal/querynode/query_node.go +++ b/internal/querynode/query_node.go @@ -191,7 +191,7 @@ func (node *QueryNode) Init() error { zap.Any("EtcdEndpoints", Params.EtcdEndpoints), zap.Any("MetaRootPath", Params.MetaRootPath), ) - node.tSafeReplica = newTSafeReplica() + node.tSafeReplica = newTSafeReplica(node.queryNodeLoopCtx) streamingReplica := newCollectionReplica(node.etcdKV) historicalReplica := newCollectionReplica(node.etcdKV) @@ -410,7 +410,7 @@ func (node *QueryNode) waitChangeInfo(segmentChangeInfos *querypb.SealedSegments return nil } - return retry.Do(context.TODO(), fn, retry.Attempts(50)) + return retry.Do(node.queryNodeLoopCtx, fn, retry.Attempts(50)) } // remove the segments since it's already compacted or balanced to other querynodes diff --git a/internal/querynode/query_node_test.go b/internal/querynode/query_node_test.go index f2da9712a2..5d441f1f82 100644 --- a/internal/querynode/query_node_test.go +++ b/internal/querynode/query_node_test.go @@ -192,7 +192,7 @@ func newQueryNodeMock() *QueryNode { panic(err) } svr := NewQueryNode(ctx, msFactory) - tsReplica := newTSafeReplica() + tsReplica := newTSafeReplica(ctx) streamingReplica := newCollectionReplica(etcdKV) historicalReplica := newCollectionReplica(etcdKV) svr.historical = newHistorical(svr.queryNodeLoopCtx, historicalReplica, etcdKV, tsReplica) diff --git a/internal/querynode/query_service_test.go b/internal/querynode/query_service_test.go index 4ae9cc604f..c2e9bd3487 100644 --- a/internal/querynode/query_service_test.go +++ b/internal/querynode/query_service_test.go @@ -221,7 +221,7 @@ func TestQueryService_addQueryCollection(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) his, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) diff --git a/internal/querynode/segment_test.go b/internal/querynode/segment_test.go index 3aa1c50187..78f2517dc8 100644 --- a/internal/querynode/segment_test.go +++ b/internal/querynode/segment_test.go @@ -886,7 +886,7 @@ func TestSegment_indexInfoTest(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) h, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) @@ -939,7 +939,7 @@ func TestSegment_indexInfoTest(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) h, err := genSimpleHistorical(ctx, tSafe) assert.NoError(t, err) diff --git a/internal/querynode/streaming_test.go b/internal/querynode/streaming_test.go index b34a9e7a06..de1bd5e711 100644 --- a/internal/querynode/streaming_test.go +++ b/internal/querynode/streaming_test.go @@ -22,7 +22,7 @@ func TestStreaming_streaming(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) streaming, err := genSimpleStreaming(ctx, tSafe) assert.NoError(t, err) defer streaming.close() @@ -35,7 +35,7 @@ func TestStreaming_search(t *testing.T) { defer cancel() t.Run("test search", func(t *testing.T) { - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) streaming, err := genSimpleStreaming(ctx, tSafe) assert.NoError(t, err) defer streaming.close() @@ -54,7 +54,7 @@ func TestStreaming_search(t *testing.T) { }) t.Run("test run empty partition", func(t *testing.T) { - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) streaming, err := genSimpleStreaming(ctx, tSafe) assert.NoError(t, err) defer streaming.close() @@ -73,7 +73,7 @@ func TestStreaming_search(t *testing.T) { }) t.Run("test run empty partition and loadCollection", func(t *testing.T) { - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) streaming, err := genSimpleStreaming(ctx, tSafe) assert.NoError(t, err) defer streaming.close() @@ -99,7 +99,7 @@ func TestStreaming_search(t *testing.T) { }) t.Run("test run empty partition and loadPartition", func(t *testing.T) { - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) streaming, err := genSimpleStreaming(ctx, tSafe) assert.NoError(t, err) defer streaming.close() @@ -124,7 +124,7 @@ func TestStreaming_search(t *testing.T) { }) t.Run("test no partitions in collection", func(t *testing.T) { - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) streaming, err := genSimpleStreaming(ctx, tSafe) assert.NoError(t, err) defer streaming.close() @@ -146,7 +146,7 @@ func TestStreaming_search(t *testing.T) { }) t.Run("test search failed", func(t *testing.T) { - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) streaming, err := genSimpleStreaming(ctx, tSafe) assert.NoError(t, err) defer streaming.close() @@ -173,7 +173,7 @@ func TestStreaming_retrieve(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - tSafe := newTSafeReplica() + tSafe := newTSafeReplica(ctx) streaming, err := genSimpleStreaming(ctx, tSafe) assert.NoError(t, err) defer streaming.close() diff --git a/internal/querynode/tsafe_replica.go b/internal/querynode/tsafe_replica.go index 50cd5f89dd..f318868bd5 100644 --- a/internal/querynode/tsafe_replica.go +++ b/internal/querynode/tsafe_replica.go @@ -40,6 +40,7 @@ type tSafeRef struct { type tSafeReplica struct { mu sync.Mutex // guards tSafes tSafes map[Channel]*tSafeRef // map[vChannel]tSafeRef + ctx context.Context } func (t *tSafeReplica) getTSafe(vChannel Channel) (Timestamp, error) { @@ -80,10 +81,9 @@ func (t *tSafeReplica) getTSaferPrivate(vChannel Channel) (tSafer, error) { func (t *tSafeReplica) addTSafe(vChannel Channel) { t.mu.Lock() defer t.mu.Unlock() - ctx := context.Background() if _, ok := t.tSafes[vChannel]; !ok { t.tSafes[vChannel] = &tSafeRef{ - tSafer: newTSafe(ctx, vChannel), + tSafer: newTSafe(t.ctx, vChannel), ref: 1, } t.tSafes[vChannel].tSafer.start() @@ -149,9 +149,10 @@ func (t *tSafeReplica) registerTSafeWatcher(vChannel Channel, watcher *tSafeWatc return safer.registerTSafeWatcher(watcher) } -func newTSafeReplica() TSafeReplicaInterface { +func newTSafeReplica(ctx context.Context) TSafeReplicaInterface { var replica TSafeReplicaInterface = &tSafeReplica{ tSafes: make(map[string]*tSafeRef), + ctx: ctx, } return replica } diff --git a/internal/querynode/tsafe_replica_test.go b/internal/querynode/tsafe_replica_test.go index cb1ada9a88..3eb9f04fca 100644 --- a/internal/querynode/tsafe_replica_test.go +++ b/internal/querynode/tsafe_replica_test.go @@ -12,13 +12,14 @@ package querynode import ( + "context" "testing" "github.com/stretchr/testify/assert" ) func TestTSafeReplica_valid(t *testing.T) { - replica := newTSafeReplica() + replica := newTSafeReplica(context.Background()) replica.addTSafe(defaultVChannel) watcher := newTSafeWatcher() @@ -38,7 +39,7 @@ func TestTSafeReplica_valid(t *testing.T) { } func TestTSafeReplica_invalid(t *testing.T) { - replica := newTSafeReplica() + replica := newTSafeReplica(context.Background()) replica.addTSafe(defaultVChannel) watcher := newTSafeWatcher() diff --git a/internal/util/grpcclient/client.go b/internal/util/grpcclient/client.go new file mode 100644 index 0000000000..b09252193d --- /dev/null +++ b/internal/util/grpcclient/client.go @@ -0,0 +1,230 @@ +package grpcclient + +import ( + "context" + "fmt" + "sync" + "time" + + grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware" + grpcretry "github.com/grpc-ecosystem/go-grpc-middleware/retry" + grpcopentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" + + "go.uber.org/zap" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/keepalive" + + "github.com/milvus-io/milvus/internal/log" + "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/retry" + "github.com/milvus-io/milvus/internal/util/trace" +) + +type GrpcClient interface { + SetRole(string) + GetRole() string + SetGetAddrFunc(func() (string, error)) + SetNewGrpcClientFunc(func(cc *grpc.ClientConn) interface{}) + GetGrpcClient(ctx context.Context) (interface{}, error) + ReCall(ctx context.Context, caller func(client interface{}) (interface{}, error)) (interface{}, error) + Call(ctx context.Context, caller func(client interface{}) (interface{}, error)) (interface{}, error) + Close() error +} + +type ClientBase struct { + getAddrFunc func() (string, error) + newGrpcClient func(cc *grpc.ClientConn) interface{} + + grpcClient interface{} + conn *grpc.ClientConn + grpcClientMtx sync.RWMutex + role string + ClientMaxSendSize int + ClientMaxRecvSize int +} + +func (c *ClientBase) SetRole(role string) { + c.role = role +} + +func (c *ClientBase) GetRole() string { + return c.role +} + +func (c *ClientBase) SetGetAddrFunc(f func() (string, error)) { + c.getAddrFunc = f +} + +func (c *ClientBase) SetNewGrpcClientFunc(f func(cc *grpc.ClientConn) interface{}) { + c.newGrpcClient = f +} + +func (c *ClientBase) GetGrpcClient(ctx context.Context) (interface{}, error) { + c.grpcClientMtx.RLock() + + if c.grpcClient != nil { + defer c.grpcClientMtx.RUnlock() + return c.grpcClient, nil + } + c.grpcClientMtx.RUnlock() + + c.grpcClientMtx.Lock() + defer c.grpcClientMtx.Unlock() + + if c.grpcClient != nil { + return c.grpcClient, nil + } + + err := c.connect(ctx, retry.Attempts(5)) + if err != nil { + return nil, err + } + + return c.grpcClient, nil +} + +func (c *ClientBase) resetConnection(client interface{}) { + c.grpcClientMtx.Lock() + defer c.grpcClientMtx.Unlock() + if c.grpcClient == nil { + return + } + + if client != c.grpcClient { + return + } + if c.conn != nil { + _ = c.conn.Close() + } + c.conn = nil + c.grpcClient = nil +} + +func (c *ClientBase) connect(ctx context.Context, retryOptions ...retry.Option) error { + var kacp = keepalive.ClientParameters{ + Time: 60 * time.Second, // send pings every 60 seconds if there is no activity + Timeout: 6 * time.Second, // wait 6 second for ping ack before considering the connection dead + PermitWithoutStream: true, // send pings even without active streams + } + + var err error + var addr string + connectServiceFunc := func() error { + addr, err = c.getAddrFunc() + if err != nil { + log.Debug(c.GetRole()+" client getAddr failed", zap.Error(err)) + return err + } + opts := trace.GetInterceptorOpts() + ctx1, cancel := context.WithTimeout(ctx, 15*time.Second) + defer cancel() + conn, err2 := grpc.DialContext(ctx1, addr, + grpc.WithKeepaliveParams(kacp), + grpc.WithInsecure(), grpc.WithBlock(), + grpc.WithDefaultCallOptions( + grpc.MaxCallRecvMsgSize(c.ClientMaxRecvSize), + grpc.MaxCallSendMsgSize(c.ClientMaxSendSize)), + grpc.WithUnaryInterceptor( + grpcmiddleware.ChainUnaryClient( + grpcretry.UnaryClientInterceptor( + grpcretry.WithMax(3), + grpcretry.WithCodes(codes.Aborted, codes.Unavailable), + ), + grpcopentracing.UnaryClientInterceptor(opts...), + )), + grpc.WithStreamInterceptor( + grpcmiddleware.ChainStreamClient( + grpcretry.StreamClientInterceptor(grpcretry.WithMax(3), + grpcretry.WithCodes(codes.Aborted, codes.Unavailable), + ), + grpcopentracing.StreamClientInterceptor(opts...), + )), + ) + if err2 != nil { + return err2 + } + if c.conn != nil { + _ = c.conn.Close() + } + c.conn = conn + return nil + } + + err = retry.Do(ctx, connectServiceFunc, retryOptions...) + if err != nil { + log.Debug(c.GetRole()+" client try reconnect failed", zap.Error(err)) + return err + } + c.grpcClient = c.newGrpcClient(c.conn) + return nil +} + +func (c *ClientBase) callOnce(ctx context.Context, caller func(client interface{}) (interface{}, error)) (interface{}, error) { + client, err := c.GetGrpcClient(ctx) + if err != nil { + return nil, err + } + + ret, err2 := caller(client) + if err2 == nil { + return ret, nil + } + if err2 == context.Canceled || err2 == context.DeadlineExceeded { + return nil, err2 + } + + log.Debug(c.GetRole()+" ClientBase grpc error, start to reset connection", zap.Error(err2)) + + c.resetConnection(client) + return ret, err2 +} + +func (c *ClientBase) Call(ctx context.Context, caller func(client interface{}) (interface{}, error)) (interface{}, error) { + if !funcutil.CheckCtxValid(ctx) { + return nil, ctx.Err() + } + + ret, err := c.callOnce(ctx, caller) + if err != nil { + traceErr := fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) + log.Error(c.GetRole()+" ClientBase Call grpc first call get error ", zap.Error(traceErr)) + return nil, traceErr + } + return ret, err +} + +func (c *ClientBase) ReCall(ctx context.Context, caller func(client interface{}) (interface{}, error)) (interface{}, error) { + if !funcutil.CheckCtxValid(ctx) { + return nil, ctx.Err() + } + + ret, err := c.callOnce(ctx, caller) + if err == nil { + return ret, nil + } + + traceErr := fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) + log.Warn(c.GetRole()+" ClientBase ReCall grpc first call get error ", zap.Error(traceErr)) + + if !funcutil.CheckCtxValid(ctx) { + return nil, ctx.Err() + } + + ret, err = c.callOnce(ctx, caller) + if err != nil { + traceErr = fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) + log.Error(c.GetRole()+" ClientBase ReCall grpc second call get error ", zap.Error(traceErr)) + return nil, traceErr + } + return ret, err +} + +func (c *ClientBase) Close() error { + c.grpcClientMtx.Lock() + defer c.grpcClientMtx.Unlock() + if c.conn != nil { + return c.conn.Close() + } + return nil +} diff --git a/internal/util/grpcclient/client_test.go b/internal/util/grpcclient/client_test.go new file mode 100644 index 0000000000..1ef1bf304e --- /dev/null +++ b/internal/util/grpcclient/client_test.go @@ -0,0 +1,19 @@ +package grpcclient + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestClientBase_SetRole(t *testing.T) { + base := ClientBase{} + expect := "abc" + base.SetRole("abc") + assert.Equal(t, expect, base.GetRole()) +} + +func TestClientBase_GetRole(t *testing.T) { + base := ClientBase{} + assert.Equal(t, "", base.GetRole()) +} diff --git a/internal/util/mock/datacoord_client.go b/internal/util/mock/datacoord_client.go new file mode 100644 index 0000000000..a2d2638571 --- /dev/null +++ b/internal/util/mock/datacoord_client.go @@ -0,0 +1,119 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" +) + +type DataCoordClient struct { + Err error +} + +func (m *DataCoordClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) { + return &internalpb.ComponentStates{}, m.Err +} + +func (m *DataCoordClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.Err +} + +func (m *DataCoordClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.Err +} + +func (m *DataCoordClient) Flush(ctx context.Context, in *datapb.FlushRequest, opts ...grpc.CallOption) (*datapb.FlushResponse, error) { + return &datapb.FlushResponse{}, m.Err +} + +func (m *DataCoordClient) AssignSegmentID(ctx context.Context, in *datapb.AssignSegmentIDRequest, opts ...grpc.CallOption) (*datapb.AssignSegmentIDResponse, error) { + return &datapb.AssignSegmentIDResponse{}, m.Err +} + +func (m *DataCoordClient) GetSegmentInfo(ctx context.Context, in *datapb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*datapb.GetSegmentInfoResponse, error) { + return &datapb.GetSegmentInfoResponse{}, m.Err +} + +func (m *DataCoordClient) GetSegmentStates(ctx context.Context, in *datapb.GetSegmentStatesRequest, opts ...grpc.CallOption) (*datapb.GetSegmentStatesResponse, error) { + return &datapb.GetSegmentStatesResponse{}, m.Err +} + +func (m *DataCoordClient) GetInsertBinlogPaths(ctx context.Context, in *datapb.GetInsertBinlogPathsRequest, opts ...grpc.CallOption) (*datapb.GetInsertBinlogPathsResponse, error) { + return &datapb.GetInsertBinlogPathsResponse{}, m.Err +} + +func (m *DataCoordClient) GetCollectionStatistics(ctx context.Context, in *datapb.GetCollectionStatisticsRequest, opts ...grpc.CallOption) (*datapb.GetCollectionStatisticsResponse, error) { + return &datapb.GetCollectionStatisticsResponse{}, m.Err +} + +func (m *DataCoordClient) GetPartitionStatistics(ctx context.Context, in *datapb.GetPartitionStatisticsRequest, opts ...grpc.CallOption) (*datapb.GetPartitionStatisticsResponse, error) { + return &datapb.GetPartitionStatisticsResponse{}, m.Err +} + +func (m *DataCoordClient) GetSegmentInfoChannel(ctx context.Context, in *datapb.GetSegmentInfoChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.Err +} + +func (m *DataCoordClient) SaveBinlogPaths(ctx context.Context, in *datapb.SaveBinlogPathsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *DataCoordClient) GetRecoveryInfo(ctx context.Context, in *datapb.GetRecoveryInfoRequest, opts ...grpc.CallOption) (*datapb.GetRecoveryInfoResponse, error) { + return &datapb.GetRecoveryInfoResponse{}, m.Err +} + +func (m *DataCoordClient) GetFlushedSegments(ctx context.Context, in *datapb.GetFlushedSegmentsRequest, opts ...grpc.CallOption) (*datapb.GetFlushedSegmentsResponse, error) { + return &datapb.GetFlushedSegmentsResponse{}, m.Err +} + +func (m *DataCoordClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + return &milvuspb.GetMetricsResponse{}, m.Err +} + +func (m *DataCoordClient) CompleteCompaction(ctx context.Context, req *datapb.CompactionResult, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *DataCoordClient) ManualCompaction(ctx context.Context, in *milvuspb.ManualCompactionRequest, opts ...grpc.CallOption) (*milvuspb.ManualCompactionResponse, error) { + return &milvuspb.ManualCompactionResponse{}, m.Err +} + +func (m *DataCoordClient) GetCompactionState(ctx context.Context, in *milvuspb.GetCompactionStateRequest, opts ...grpc.CallOption) (*milvuspb.GetCompactionStateResponse, error) { + return &milvuspb.GetCompactionStateResponse{}, m.Err +} + +func (m *DataCoordClient) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.GetCompactionPlansRequest, opts ...grpc.CallOption) (*milvuspb.GetCompactionPlansResponse, error) { + return &milvuspb.GetCompactionPlansResponse{}, m.Err +} + +func (m *DataCoordClient) WatchChannels(ctx context.Context, req *datapb.WatchChannelsRequest, opts ...grpc.CallOption) (*datapb.WatchChannelsResponse, error) { + return &datapb.WatchChannelsResponse{}, m.Err +} +func (m *DataCoordClient) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest, opts ...grpc.CallOption) (*milvuspb.GetFlushStateResponse, error) { + return &milvuspb.GetFlushStateResponse{}, m.Err +} + +func (m *DataCoordClient) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest, opts ...grpc.CallOption) (*datapb.DropVirtualChannelResponse, error) { + return &datapb.DropVirtualChannelResponse{}, m.Err +} diff --git a/internal/util/mock/datanode_client.go b/internal/util/mock/datanode_client.go new file mode 100644 index 0000000000..263bb95e68 --- /dev/null +++ b/internal/util/mock/datanode_client.go @@ -0,0 +1,56 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" +) + +type DataNodeClient struct { + Err error +} + +func (m *DataNodeClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) { + return &internalpb.ComponentStates{}, m.Err +} + +func (m *DataNodeClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.Err +} + +func (m *DataNodeClient) WatchDmChannels(ctx context.Context, in *datapb.WatchDmChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *DataNodeClient) FlushSegments(ctx context.Context, in *datapb.FlushSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *DataNodeClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + return &milvuspb.GetMetricsResponse{}, m.Err +} + +func (m *DataNodeClient) Compaction(ctx context.Context, req *datapb.CompactionPlan, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} diff --git a/internal/util/mock/grpcclient.go b/internal/util/mock/grpcclient.go new file mode 100644 index 0000000000..f26d31da2d --- /dev/null +++ b/internal/util/mock/grpcclient.go @@ -0,0 +1,153 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + "fmt" + "sync" + + "go.uber.org/zap" + "google.golang.org/grpc" + + "github.com/milvus-io/milvus/internal/log" + "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/retry" + "github.com/milvus-io/milvus/internal/util/trace" +) + +type ClientBase struct { + getAddrFunc func() (string, error) + newGrpcClient func(cc *grpc.ClientConn) interface{} + + grpcClient interface{} + conn *grpc.ClientConn + grpcClientMtx sync.RWMutex + GetGrpcClientErr error + role string +} + +func (c *ClientBase) SetGetAddrFunc(f func() (string, error)) { + c.getAddrFunc = f +} + +func (c *ClientBase) GetRole() string { + return c.role +} + +func (c *ClientBase) SetRole(role string) { + c.role = role +} + +func (c *ClientBase) SetNewGrpcClientFunc(f func(cc *grpc.ClientConn) interface{}) { + c.newGrpcClient = f +} + +func (c *ClientBase) GetGrpcClient(ctx context.Context) (interface{}, error) { + c.grpcClientMtx.RLock() + defer c.grpcClientMtx.RUnlock() + c.connect(ctx) + return c.grpcClient, c.GetGrpcClientErr +} + +func (c *ClientBase) resetConnection(client interface{}) { + c.grpcClientMtx.Lock() + defer c.grpcClientMtx.Unlock() + if c.grpcClient == nil { + return + } + + if client != c.grpcClient { + return + } + if c.conn != nil { + _ = c.conn.Close() + } + c.conn = nil + c.grpcClient = nil +} + +func (c *ClientBase) connect(ctx context.Context, retryOptions ...retry.Option) error { + c.grpcClient = c.newGrpcClient(c.conn) + return nil +} + +func (c *ClientBase) callOnce(ctx context.Context, caller func(client interface{}) (interface{}, error)) (interface{}, error) { + client, err := c.GetGrpcClient(ctx) + if err != nil { + return nil, err + } + + ret, err2 := caller(client) + if err2 == nil { + return ret, nil + } + if err2 == context.Canceled || err2 == context.DeadlineExceeded { + return nil, err2 + } + + c.resetConnection(client) + return ret, err2 +} + +func (c *ClientBase) Call(ctx context.Context, caller func(client interface{}) (interface{}, error)) (interface{}, error) { + if !funcutil.CheckCtxValid(ctx) { + return nil, ctx.Err() + } + + ret, err := c.callOnce(ctx, caller) + if err != nil { + traceErr := fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) + log.Error("ClientBase Call grpc first call get error ", zap.Error(traceErr)) + return nil, traceErr + } + return ret, err +} + +func (c *ClientBase) ReCall(ctx context.Context, caller func(client interface{}) (interface{}, error)) (interface{}, error) { + if !funcutil.CheckCtxValid(ctx) { + return nil, ctx.Err() + } + ret, err := c.callOnce(ctx, caller) + if err == nil { + return ret, nil + } + + traceErr := fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) + log.Warn("ClientBase client grpc first call get error ", zap.Error(traceErr)) + + if !funcutil.CheckCtxValid(ctx) { + return nil, ctx.Err() + } + + ret, err = c.callOnce(ctx, caller) + if err != nil { + traceErr = fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) + log.Error("ClientBase client grpc second call get error ", zap.Error(traceErr)) + return nil, traceErr + } + return ret, err +} + +func (c *ClientBase) Close() error { + c.grpcClientMtx.Lock() + defer c.grpcClientMtx.Unlock() + if c.conn != nil { + return c.conn.Close() + } + return nil +} diff --git a/internal/util/mock/indexnode_client.go b/internal/util/mock/indexnode_client.go new file mode 100644 index 0000000000..91ac33fe5d --- /dev/null +++ b/internal/util/mock/indexnode_client.go @@ -0,0 +1,52 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/indexpb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" +) + +type IndexNodeClient struct { + Err error +} + +func (m *IndexNodeClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) { + return &internalpb.ComponentStates{}, m.Err +} + +func (m *IndexNodeClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.Err +} + +func (m *IndexNodeClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.Err +} + +func (m *IndexNodeClient) CreateIndex(ctx context.Context, in *indexpb.CreateIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *IndexNodeClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + return &milvuspb.GetMetricsResponse{}, m.Err +} diff --git a/internal/util/mock/proxy_client.go b/internal/util/mock/proxy_client.go new file mode 100644 index 0000000000..cdc7b03c5b --- /dev/null +++ b/internal/util/mock/proxy_client.go @@ -0,0 +1,52 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/proto/proxypb" +) + +type ProxyClient struct { + Err error +} + +func (m *ProxyClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) { + return &internalpb.ComponentStates{}, m.Err +} + +func (m *ProxyClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.Err +} + +func (m *ProxyClient) InvalidateCollectionMetaCache(ctx context.Context, in *proxypb.InvalidateCollMetaCacheRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *ProxyClient) GetDdChannel(ctx context.Context, in *internalpb.GetDdChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.Err +} + +func (m *ProxyClient) ReleaseDQLMessageStream(ctx context.Context, in *proxypb.ReleaseDQLMessageStreamRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} diff --git a/internal/util/mock/querycoord_client.go b/internal/util/mock/querycoord_client.go new file mode 100644 index 0000000000..f647fc1712 --- /dev/null +++ b/internal/util/mock/querycoord_client.go @@ -0,0 +1,88 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/proto/querypb" +) + +type QueryCoordClient struct { + Err error +} + +func (m *QueryCoordClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) { + return &internalpb.ComponentStates{}, m.Err +} + +func (m *QueryCoordClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.Err +} + +func (m *QueryCoordClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.Err +} + +func (m *QueryCoordClient) ShowCollections(ctx context.Context, in *querypb.ShowCollectionsRequest, opts ...grpc.CallOption) (*querypb.ShowCollectionsResponse, error) { + return &querypb.ShowCollectionsResponse{}, m.Err +} + +func (m *QueryCoordClient) ShowPartitions(ctx context.Context, in *querypb.ShowPartitionsRequest, opts ...grpc.CallOption) (*querypb.ShowPartitionsResponse, error) { + return &querypb.ShowPartitionsResponse{}, m.Err +} + +func (m *QueryCoordClient) LoadPartitions(ctx context.Context, in *querypb.LoadPartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *QueryCoordClient) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *QueryCoordClient) LoadCollection(ctx context.Context, in *querypb.LoadCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *QueryCoordClient) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *QueryCoordClient) CreateQueryChannel(ctx context.Context, in *querypb.CreateQueryChannelRequest, opts ...grpc.CallOption) (*querypb.CreateQueryChannelResponse, error) { + return &querypb.CreateQueryChannelResponse{}, m.Err +} + +func (m *QueryCoordClient) GetPartitionStates(ctx context.Context, in *querypb.GetPartitionStatesRequest, opts ...grpc.CallOption) (*querypb.GetPartitionStatesResponse, error) { + return &querypb.GetPartitionStatesResponse{}, m.Err +} + +func (m *QueryCoordClient) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*querypb.GetSegmentInfoResponse, error) { + return &querypb.GetSegmentInfoResponse{}, m.Err +} + +func (m *QueryCoordClient) LoadBalance(ctx context.Context, in *querypb.LoadBalanceRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *QueryCoordClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + return &milvuspb.GetMetricsResponse{}, m.Err +} diff --git a/internal/util/mock/querynode_client.go b/internal/util/mock/querynode_client.go new file mode 100644 index 0000000000..3a877cafe9 --- /dev/null +++ b/internal/util/mock/querynode_client.go @@ -0,0 +1,83 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/proto/querypb" +) + +type QueryNodeClient struct { + Err error +} + +func (m *QueryNodeClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) { + return &internalpb.ComponentStates{}, m.Err +} + +func (m *QueryNodeClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.Err +} +func (m *QueryNodeClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.Err +} + +func (m *QueryNodeClient) AddQueryChannel(ctx context.Context, in *querypb.AddQueryChannelRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *QueryNodeClient) RemoveQueryChannel(ctx context.Context, in *querypb.RemoveQueryChannelRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *QueryNodeClient) WatchDmChannels(ctx context.Context, in *querypb.WatchDmChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *QueryNodeClient) WatchDeltaChannels(ctx context.Context, in *querypb.WatchDeltaChannelsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *QueryNodeClient) LoadSegments(ctx context.Context, in *querypb.LoadSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *QueryNodeClient) ReleaseCollection(ctx context.Context, in *querypb.ReleaseCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *QueryNodeClient) ReleasePartitions(ctx context.Context, in *querypb.ReleasePartitionsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *QueryNodeClient) ReleaseSegments(ctx context.Context, in *querypb.ReleaseSegmentsRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *QueryNodeClient) GetSegmentInfo(ctx context.Context, in *querypb.GetSegmentInfoRequest, opts ...grpc.CallOption) (*querypb.GetSegmentInfoResponse, error) { + return &querypb.GetSegmentInfoResponse{}, m.Err +} + +func (m *QueryNodeClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + return &milvuspb.GetMetricsResponse{}, m.Err +} diff --git a/internal/util/mock/rootcoord_client.go b/internal/util/mock/rootcoord_client.go new file mode 100644 index 0000000000..fddf1698b6 --- /dev/null +++ b/internal/util/mock/rootcoord_client.go @@ -0,0 +1,136 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mock + +import ( + "context" + + "google.golang.org/grpc" + + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/datapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/proto/proxypb" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" +) + +type RootCoordClient struct { + Err error +} + +func (m *RootCoordClient) GetComponentStates(ctx context.Context, in *internalpb.GetComponentStatesRequest, opts ...grpc.CallOption) (*internalpb.ComponentStates, error) { + return &internalpb.ComponentStates{}, m.Err +} +func (m *RootCoordClient) GetTimeTickChannel(ctx context.Context, in *internalpb.GetTimeTickChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.Err +} +func (m *RootCoordClient) GetStatisticsChannel(ctx context.Context, in *internalpb.GetStatisticsChannelRequest, opts ...grpc.CallOption) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{}, m.Err +} + +func (m *RootCoordClient) CreateCollection(ctx context.Context, in *milvuspb.CreateCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *RootCoordClient) DropCollection(ctx context.Context, in *milvuspb.DropCollectionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *RootCoordClient) HasCollection(ctx context.Context, in *milvuspb.HasCollectionRequest, opts ...grpc.CallOption) (*milvuspb.BoolResponse, error) { + return &milvuspb.BoolResponse{}, m.Err +} + +func (m *RootCoordClient) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) { + return &milvuspb.DescribeCollectionResponse{}, m.Err +} + +func (m *RootCoordClient) CreateAlias(ctx context.Context, in *milvuspb.CreateAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *RootCoordClient) DropAlias(ctx context.Context, in *milvuspb.DropAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *RootCoordClient) AlterAlias(ctx context.Context, in *milvuspb.AlterAliasRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *RootCoordClient) ShowCollections(ctx context.Context, in *milvuspb.ShowCollectionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowCollectionsResponse, error) { + return &milvuspb.ShowCollectionsResponse{}, m.Err +} + +func (m *RootCoordClient) CreatePartition(ctx context.Context, in *milvuspb.CreatePartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *RootCoordClient) DropPartition(ctx context.Context, in *milvuspb.DropPartitionRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *RootCoordClient) HasPartition(ctx context.Context, in *milvuspb.HasPartitionRequest, opts ...grpc.CallOption) (*milvuspb.BoolResponse, error) { + return &milvuspb.BoolResponse{}, m.Err +} + +func (m *RootCoordClient) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionsRequest, opts ...grpc.CallOption) (*milvuspb.ShowPartitionsResponse, error) { + return &milvuspb.ShowPartitionsResponse{}, m.Err +} + +func (m *RootCoordClient) DescribeSegment(ctx context.Context, in *milvuspb.DescribeSegmentRequest, opts ...grpc.CallOption) (*milvuspb.DescribeSegmentResponse, error) { + return &milvuspb.DescribeSegmentResponse{}, m.Err +} + +func (m *RootCoordClient) ShowSegments(ctx context.Context, in *milvuspb.ShowSegmentsRequest, opts ...grpc.CallOption) (*milvuspb.ShowSegmentsResponse, error) { + return &milvuspb.ShowSegmentsResponse{}, m.Err +} + +func (m *RootCoordClient) CreateIndex(ctx context.Context, in *milvuspb.CreateIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *RootCoordClient) DescribeIndex(ctx context.Context, in *milvuspb.DescribeIndexRequest, opts ...grpc.CallOption) (*milvuspb.DescribeIndexResponse, error) { + return &milvuspb.DescribeIndexResponse{}, m.Err +} + +func (m *RootCoordClient) DropIndex(ctx context.Context, in *milvuspb.DropIndexRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *RootCoordClient) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimestampRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocTimestampResponse, error) { + return &rootcoordpb.AllocTimestampResponse{}, m.Err +} + +func (m *RootCoordClient) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest, opts ...grpc.CallOption) (*rootcoordpb.AllocIDResponse, error) { + return &rootcoordpb.AllocIDResponse{}, m.Err +} + +func (m *RootCoordClient) UpdateChannelTimeTick(ctx context.Context, in *internalpb.ChannelTimeTickMsg, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *RootCoordClient) ReleaseDQLMessageStream(ctx context.Context, in *proxypb.ReleaseDQLMessageStreamRequest, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *RootCoordClient) SegmentFlushCompleted(ctx context.Context, in *datapb.SegmentFlushCompletedMsg, opts ...grpc.CallOption) (*commonpb.Status, error) { + return &commonpb.Status{}, m.Err +} + +func (m *RootCoordClient) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest, opts ...grpc.CallOption) (*milvuspb.GetMetricsResponse, error) { + return &milvuspb.GetMetricsResponse{}, m.Err +}