From 0a83cdfe1bd7ca064e9be5ca5c87f5b2bc92c24b Mon Sep 17 00:00:00 2001 From: Enwei Jiao Date: Mon, 17 Oct 2022 20:05:26 +0800 Subject: [PATCH] Refactor grpclient, make it generic (#19791) Signed-off-by: Enwei Jiao Signed-off-by: Enwei Jiao --- .../distributed/datacoord/client/client.go | 134 +++++++------- .../datacoord/client/client_test.go | 15 +- .../distributed/datanode/client/client.go | 54 +++--- .../datanode/client/client_test.go | 13 +- .../distributed/indexcoord/client/client.go | 50 +++--- .../distributed/indexnode/client/client.go | 38 ++-- .../indexnode/client/client_test.go | 12 +- internal/distributed/proxy/client/client.go | 38 ++-- .../distributed/proxy/client/client_test.go | 13 +- .../distributed/querycoord/client/client.go | 70 ++++---- .../querycoord/client/client_test.go | 15 +- .../distributed/querynode/client/client.go | 78 ++++---- .../querynode/client/client_test.go | 17 +- .../distributed/rootcoord/client/client.go | 166 +++++++++--------- .../rootcoord/client/client_test.go | 13 +- internal/util/generic/generic.go | 31 ++++ internal/util/grpcclient/client.go | 70 ++++---- internal/util/grpcclient/client_test.go | 45 +++-- internal/util/mock/grpcclient.go | 43 ++--- tests/python_client/testcases/test_utility.py | 4 +- 20 files changed, 478 insertions(+), 441 deletions(-) create mode 100644 internal/util/generic/generic.go diff --git a/internal/distributed/datacoord/client/client.go b/internal/distributed/datacoord/client/client.go index ddea9a132f..dad7f53db9 100644 --- a/internal/distributed/datacoord/client/client.go +++ b/internal/distributed/datacoord/client/client.go @@ -43,7 +43,7 @@ var _ types.DataCoord = (*Client)(nil) // Client is the datacoord grpc client type Client struct { - grpcClient grpcclient.GrpcClient + grpcClient grpcclient.GrpcClient[datapb.DataCoordClient] sess *sessionutil.Session } @@ -57,7 +57,7 @@ func NewClient(ctx context.Context, metaRoot string, etcdCli *clientv3.Client) ( } ClientParams.InitOnce(typeutil.DataCoordRole) client := &Client{ - grpcClient: &grpcclient.ClientBase{ + grpcClient: &grpcclient.ClientBase[datapb.DataCoordClient]{ ClientMaxRecvSize: ClientParams.ClientMaxRecvSize, ClientMaxSendSize: ClientParams.ClientMaxSendSize, DialTimeout: ClientParams.DialTimeout, @@ -78,7 +78,7 @@ func NewClient(ctx context.Context, metaRoot string, etcdCli *clientv3.Client) ( return client, nil } -func (c *Client) newGrpcClient(cc *grpc.ClientConn) interface{} { +func (c *Client) newGrpcClient(cc *grpc.ClientConn) datapb.DataCoordClient { return datapb.NewDataCoordClient(cc) } @@ -119,11 +119,11 @@ func (c *Client) Register() error { // GetComponentStates calls DataCoord GetComponentStates services func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) + return client.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) }) if err != nil || ret == nil { return nil, err @@ -133,11 +133,11 @@ func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentSta // GetTimeTickChannel return the name of time tick channel. func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) + return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -147,11 +147,11 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon // GetStatisticsChannel return the name of statistics channel. func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.grpcClient.Call(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.Call(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -161,11 +161,11 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp // Flush flushes a collection's data func (c *Client) Flush(ctx context.Context, req *datapb.FlushRequest) (*datapb.FlushResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).Flush(ctx, req) + return client.Flush(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -187,11 +187,11 @@ func (c *Client) Flush(ctx context.Context, req *datapb.FlushRequest) (*datapb.F // if the VChannel is newly used, `WatchDmlChannels` will be invoked to notify a `DataNode`(selected by policy) to watch it // if there is anything make the allocation impossible, the response will not contain the corresponding result func (c *Client) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).AssignSegmentID(ctx, req) + return client.AssignSegmentID(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -209,11 +209,11 @@ func (c *Client) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentI // otherwise the Segment State and Start position information will be returned // error is returned only when some communication issue occurs func (c *Client) GetSegmentStates(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).GetSegmentStates(ctx, req) + return client.GetSegmentStates(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -230,11 +230,11 @@ func (c *Client) GetSegmentStates(ctx context.Context, req *datapb.GetSegmentSta // and corresponding binlog path list // error is returned only when some communication issue occurs func (c *Client) GetInsertBinlogPaths(ctx context.Context, req *datapb.GetInsertBinlogPathsRequest) (*datapb.GetInsertBinlogPathsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).GetInsertBinlogPaths(ctx, req) + return client.GetInsertBinlogPaths(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -251,11 +251,11 @@ func (c *Client) GetInsertBinlogPaths(ctx context.Context, req *datapb.GetInsert // only row count for now // error is returned only when some communication issue occurs func (c *Client) GetCollectionStatistics(ctx context.Context, req *datapb.GetCollectionStatisticsRequest) (*datapb.GetCollectionStatisticsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).GetCollectionStatistics(ctx, req) + return client.GetCollectionStatistics(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -272,11 +272,11 @@ func (c *Client) GetCollectionStatistics(ctx context.Context, req *datapb.GetCol // only row count for now // error is returned only when some communication issue occurs func (c *Client) GetPartitionStatistics(ctx context.Context, req *datapb.GetPartitionStatisticsRequest) (*datapb.GetPartitionStatisticsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).GetPartitionStatistics(ctx, req) + return client.GetPartitionStatistics(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -287,11 +287,11 @@ func (c *Client) GetPartitionStatistics(ctx context.Context, req *datapb.GetPart // GetSegmentInfoChannel DEPRECATED // legacy api to get SegmentInfo Channel name func (c *Client) GetSegmentInfoChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).GetSegmentInfoChannel(ctx, &datapb.GetSegmentInfoChannelRequest{}) + return client.GetSegmentInfoChannel(ctx, &datapb.GetSegmentInfoChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -307,11 +307,11 @@ func (c *Client) GetSegmentInfoChannel(ctx context.Context) (*milvuspb.StringRes // response struct `GetSegmentInfoResponse` contains the list of segment info // error is returned only when some communication issue occurs func (c *Client) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoRequest) (*datapb.GetSegmentInfoResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).GetSegmentInfo(ctx, req) + return client.GetSegmentInfo(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -333,11 +333,11 @@ func (c *Client) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoR // if the constraint is broken, the checkpoint position will not be monotonically increasing and the integrity will be compromised func (c *Client) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPathsRequest) (*commonpb.Status, error) { // use Call here on purpose - ret, err := c.grpcClient.Call(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.Call(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).SaveBinlogPaths(ctx, req) + return client.SaveBinlogPaths(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -353,11 +353,11 @@ func (c *Client) SaveBinlogPaths(ctx context.Context, req *datapb.SaveBinlogPath // response struct `GetRecoveryInfoResponse` contains the list of segments info and corresponding vchannel info // error is returned only when some communication issue occurs func (c *Client) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInfoRequest) (*datapb.GetRecoveryInfoResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).GetRecoveryInfo(ctx, req) + return client.GetRecoveryInfo(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -374,11 +374,11 @@ func (c *Client) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf // response struct `GetFlushedSegmentsResponse` contains flushed segment id list // error is returned only when some communication issue occurs func (c *Client) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedSegmentsRequest) (*datapb.GetFlushedSegmentsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).GetFlushedSegments(ctx, req) + return client.GetFlushedSegments(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -395,11 +395,11 @@ func (c *Client) GetFlushedSegments(ctx context.Context, req *datapb.GetFlushedS // response struct `GetSegmentsByStatesResponse` contains segment id list // error is returned only when some communication issue occurs func (c *Client) GetSegmentsByStates(ctx context.Context, req *datapb.GetSegmentsByStatesRequest) (*datapb.GetSegmentsByStatesResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).GetSegmentsByStates(ctx, req) + return client.GetSegmentsByStates(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -409,11 +409,11 @@ func (c *Client) GetSegmentsByStates(ctx context.Context, req *datapb.GetSegment // ShowConfigurations gets specified configurations para of DataCoord func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).ShowConfigurations(ctx, req) + return client.ShowConfigurations(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -424,11 +424,11 @@ func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon // GetMetrics gets all metrics of datacoord func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).GetMetrics(ctx, req) + return client.GetMetrics(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -438,11 +438,11 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest // ManualCompaction triggers a compaction for a collection func (c *Client) ManualCompaction(ctx context.Context, req *milvuspb.ManualCompactionRequest) (*milvuspb.ManualCompactionResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).ManualCompaction(ctx, req) + return client.ManualCompaction(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -452,11 +452,11 @@ func (c *Client) ManualCompaction(ctx context.Context, req *milvuspb.ManualCompa // GetCompactionState gets the state of a compaction func (c *Client) GetCompactionState(ctx context.Context, req *milvuspb.GetCompactionStateRequest) (*milvuspb.GetCompactionStateResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).GetCompactionState(ctx, req) + return client.GetCompactionState(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -466,11 +466,11 @@ func (c *Client) GetCompactionState(ctx context.Context, req *milvuspb.GetCompac // GetCompactionStateWithPlans gets the state of a compaction by plan func (c *Client) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.GetCompactionPlansRequest) (*milvuspb.GetCompactionPlansResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).GetCompactionStateWithPlans(ctx, req) + return client.GetCompactionStateWithPlans(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -480,11 +480,11 @@ func (c *Client) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb. // WatchChannels notifies DataCoord to watch vchannels of a collection func (c *Client) WatchChannels(ctx context.Context, req *datapb.WatchChannelsRequest) (*datapb.WatchChannelsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).WatchChannels(ctx, req) + return client.WatchChannels(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -494,11 +494,11 @@ func (c *Client) WatchChannels(ctx context.Context, req *datapb.WatchChannelsReq // GetFlushState gets the flush state of multiple segments func (c *Client) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateRequest) (*milvuspb.GetFlushStateResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).GetFlushState(ctx, req) + return client.GetFlushState(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -508,11 +508,11 @@ func (c *Client) GetFlushState(ctx context.Context, req *milvuspb.GetFlushStateR // DropVirtualChannel drops virtual channel in datacoord. func (c *Client) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtualChannelRequest) (*datapb.DropVirtualChannelResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).DropVirtualChannel(ctx, req) + return client.DropVirtualChannel(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -522,11 +522,11 @@ func (c *Client) DropVirtualChannel(ctx context.Context, req *datapb.DropVirtual // SetSegmentState sets the state of a given segment. func (c *Client) SetSegmentState(ctx context.Context, req *datapb.SetSegmentStateRequest) (*datapb.SetSegmentStateResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).SetSegmentState(ctx, req) + return client.SetSegmentState(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -536,11 +536,11 @@ func (c *Client) SetSegmentState(ctx context.Context, req *datapb.SetSegmentStat // Import data files(json, numpy, etc.) on MinIO/S3 storage, read and parse them into sealed segments func (c *Client) Import(ctx context.Context, req *datapb.ImportTaskRequest) (*datapb.ImportTaskResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).Import(ctx, req) + return client.Import(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -550,11 +550,11 @@ func (c *Client) Import(ctx context.Context, req *datapb.ImportTaskRequest) (*da // UpdateSegmentStatistics is the client side caller of UpdateSegmentStatistics. func (c *Client) UpdateSegmentStatistics(ctx context.Context, req *datapb.UpdateSegmentStatisticsRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).UpdateSegmentStatistics(ctx, req) + return client.UpdateSegmentStatistics(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -564,11 +564,11 @@ func (c *Client) UpdateSegmentStatistics(ctx context.Context, req *datapb.Update // AcquireSegmentLock acquire the reference lock of the segments. func (c *Client) AcquireSegmentLock(ctx context.Context, req *datapb.AcquireSegmentLockRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).AcquireSegmentLock(ctx, req) + return client.AcquireSegmentLock(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -578,11 +578,11 @@ func (c *Client) AcquireSegmentLock(ctx context.Context, req *datapb.AcquireSegm // ReleaseSegmentLock release the reference lock of the segments. func (c *Client) ReleaseSegmentLock(ctx context.Context, req *datapb.ReleaseSegmentLockRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).ReleaseSegmentLock(ctx, req) + return client.ReleaseSegmentLock(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -592,11 +592,11 @@ func (c *Client) ReleaseSegmentLock(ctx context.Context, req *datapb.ReleaseSegm // SaveImportSegment is the DataCoord client side code for SaveImportSegment call. func (c *Client) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSegmentRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).SaveImportSegment(ctx, req) + return client.SaveImportSegment(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -605,11 +605,11 @@ func (c *Client) SaveImportSegment(ctx context.Context, req *datapb.SaveImportSe } func (c *Client) UnsetIsImportingState(ctx context.Context, req *datapb.UnsetIsImportingStateRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).UnsetIsImportingState(ctx, req) + return client.UnsetIsImportingState(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -618,11 +618,11 @@ func (c *Client) UnsetIsImportingState(ctx context.Context, req *datapb.UnsetIsI } func (c *Client) MarkSegmentsDropped(ctx context.Context, req *datapb.MarkSegmentsDroppedRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).MarkSegmentsDropped(ctx, req) + return client.MarkSegmentsDropped(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -632,11 +632,11 @@ func (c *Client) MarkSegmentsDropped(ctx context.Context, req *datapb.MarkSegmen // BroadcastAlteredCollection is the DataCoord client side code for BroadcastAlteredCollection call. func (c *Client) BroadcastAlteredCollection(ctx context.Context, req *milvuspb.AlterCollectionRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataCoordClient).BroadcastAlteredCollection(ctx, req) + return client.BroadcastAlteredCollection(ctx, req) }) if err != nil || ret == nil { return nil, err diff --git a/internal/distributed/datacoord/client/client_test.go b/internal/distributed/datacoord/client/client_test.go index 8b05d6c7d5..590d73b52c 100644 --- a/internal/distributed/datacoord/client/client_test.go +++ b/internal/distributed/datacoord/client/client_test.go @@ -21,6 +21,7 @@ import ( "errors" "testing" + "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proxy" "github.com/milvus-io/milvus/internal/util/etcd" "github.com/milvus-io/milvus/internal/util/mock" @@ -48,7 +49,7 @@ func Test_NewClient(t *testing.T) { assert.Nil(t, err) checkFunc := func(retNotNil bool) { - retCheck := func(notNil bool, ret interface{}, err error) { + retCheck := func(notNil bool, ret any, err error) { if notNil { assert.NotNil(t, ret) assert.Nil(t, err) @@ -151,11 +152,11 @@ func Test_NewClient(t *testing.T) { } } - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[datapb.DataCoordClient]{ GetGrpcClientErr: errors.New("dummy"), } - newFunc1 := func(cc *grpc.ClientConn) interface{} { + newFunc1 := func(cc *grpc.ClientConn) datapb.DataCoordClient { return &mock.GrpcDataCoordClient{Err: nil} } client.grpcClient.SetNewGrpcClientFunc(newFunc1) @@ -167,10 +168,10 @@ func Test_NewClient(t *testing.T) { assert.Nil(t, ret) assert.NotNil(t, err) - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[datapb.DataCoordClient]{ GetGrpcClientErr: nil, } - newFunc2 := func(cc *grpc.ClientConn) interface{} { + newFunc2 := func(cc *grpc.ClientConn) datapb.DataCoordClient { return &mock.GrpcDataCoordClient{Err: errors.New("dummy")} } client.grpcClient.SetNewGrpcClientFunc(newFunc2) @@ -181,10 +182,10 @@ func Test_NewClient(t *testing.T) { assert.Nil(t, ret) assert.NotNil(t, err) - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[datapb.DataCoordClient]{ GetGrpcClientErr: nil, } - newFunc3 := func(cc *grpc.ClientConn) interface{} { + newFunc3 := func(cc *grpc.ClientConn) datapb.DataCoordClient { return &mock.GrpcDataCoordClient{Err: nil} } client.grpcClient.SetNewGrpcClientFunc(newFunc3) diff --git a/internal/distributed/datanode/client/client.go b/internal/distributed/datanode/client/client.go index c096e5de65..c501fa7b48 100644 --- a/internal/distributed/datanode/client/client.go +++ b/internal/distributed/datanode/client/client.go @@ -35,7 +35,7 @@ var ClientParams paramtable.GrpcClientConfig // Client is the grpc client for DataNode type Client struct { - grpcClient grpcclient.GrpcClient + grpcClient grpcclient.GrpcClient[datapb.DataNodeClient] addr string } @@ -47,7 +47,7 @@ func NewClient(ctx context.Context, addr string) (*Client, error) { ClientParams.InitOnce(typeutil.DataNodeRole) client := &Client{ addr: addr, - grpcClient: &grpcclient.ClientBase{ + grpcClient: &grpcclient.ClientBase[datapb.DataNodeClient]{ ClientMaxRecvSize: ClientParams.ClientMaxRecvSize, ClientMaxSendSize: ClientParams.ClientMaxSendSize, DialTimeout: ClientParams.DialTimeout, @@ -89,7 +89,7 @@ func (c *Client) Register() error { return nil } -func (c *Client) newGrpcClient(cc *grpc.ClientConn) interface{} { +func (c *Client) newGrpcClient(cc *grpc.ClientConn) datapb.DataNodeClient { return datapb.NewDataNodeClient(cc) } @@ -99,11 +99,11 @@ func (c *Client) getAddr() (string, error) { // GetComponentStates returns ComponentStates func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataNodeClient).GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) + return client.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) }) if err != nil || ret == nil { return nil, err @@ -114,11 +114,11 @@ func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentSta // GetStatisticsChannel return the statistics channel in string // Statistics channel contains statistics infos of query nodes, such as segment infos, memory infos func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataNodeClient).GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -129,11 +129,11 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp // Deprecated // WatchDmChannels create consumers on dmChannels to reveive Incremental data func (c *Client) WatchDmChannels(ctx context.Context, req *datapb.WatchDmChannelsRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataNodeClient).WatchDmChannels(ctx, req) + return client.WatchDmChannels(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -150,11 +150,11 @@ func (c *Client) WatchDmChannels(ctx context.Context, req *datapb.WatchDmChannel // Return Success code in status and trigers background flush: // Log an info log if a segment is under flushing func (c *Client) FlushSegments(ctx context.Context, req *datapb.FlushSegmentsRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataNodeClient).FlushSegments(ctx, req) + return client.FlushSegments(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -164,11 +164,11 @@ func (c *Client) FlushSegments(ctx context.Context, req *datapb.FlushSegmentsReq // ShowConfigurations gets specified configurations para of DataNode func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataNodeClient).ShowConfigurations(ctx, req) + return client.ShowConfigurations(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -179,11 +179,11 @@ func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon // GetMetrics returns metrics func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataNodeClient).GetMetrics(ctx, req) + return client.GetMetrics(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -193,11 +193,11 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest // Compaction return compaction by given plan func (c *Client) Compaction(ctx context.Context, req *datapb.CompactionPlan) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataNodeClient).Compaction(ctx, req) + return client.Compaction(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -206,11 +206,11 @@ func (c *Client) Compaction(ctx context.Context, req *datapb.CompactionPlan) (*c } func (c *Client) GetCompactionState(ctx context.Context, req *datapb.CompactionStateRequest) (*datapb.CompactionStateResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataNodeClient).GetCompactionState(ctx, req) + return client.GetCompactionState(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -220,11 +220,11 @@ func (c *Client) GetCompactionState(ctx context.Context, req *datapb.CompactionS // Import data files(json, numpy, etc.) on MinIO/S3 storage, read and parse them into sealed segments func (c *Client) Import(ctx context.Context, req *datapb.ImportTaskRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataNodeClient).Import(ctx, req) + return client.Import(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -233,11 +233,11 @@ func (c *Client) Import(ctx context.Context, req *datapb.ImportTaskRequest) (*co } func (c *Client) ResendSegmentStats(ctx context.Context, req *datapb.ResendSegmentStatsRequest) (*datapb.ResendSegmentStatsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataNodeClient).ResendSegmentStats(ctx, req) + return client.ResendSegmentStats(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -247,11 +247,11 @@ func (c *Client) ResendSegmentStats(ctx context.Context, req *datapb.ResendSegme // AddImportSegment is the DataNode client side code for AddImportSegment call. func (c *Client) AddImportSegment(ctx context.Context, req *datapb.AddImportSegmentRequest) (*datapb.AddImportSegmentResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataNodeClient).AddImportSegment(ctx, req) + return client.AddImportSegment(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -261,11 +261,11 @@ func (c *Client) AddImportSegment(ctx context.Context, req *datapb.AddImportSegm // SyncSegments is the DataNode client side code for SyncSegments call. func (c *Client) SyncSegments(ctx context.Context, req *datapb.SyncSegmentsRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client datapb.DataNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(datapb.DataNodeClient).SyncSegments(ctx, req) + return client.SyncSegments(ctx, req) }) if err != nil || ret == nil { return nil, err diff --git a/internal/distributed/datanode/client/client_test.go b/internal/distributed/datanode/client/client_test.go index 54aec65e8a..bfe84a60bf 100644 --- a/internal/distributed/datanode/client/client_test.go +++ b/internal/distributed/datanode/client/client_test.go @@ -21,6 +21,7 @@ import ( "errors" "testing" + "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/util/mock" "google.golang.org/grpc" @@ -93,22 +94,22 @@ func Test_NewClient(t *testing.T) { retCheck(retNotNil, r11, err) } - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[datapb.DataNodeClient]{ GetGrpcClientErr: errors.New("dummy"), } - newFunc1 := func(cc *grpc.ClientConn) interface{} { + newFunc1 := func(cc *grpc.ClientConn) datapb.DataNodeClient { return &mock.GrpcDataNodeClient{Err: nil} } client.grpcClient.SetNewGrpcClientFunc(newFunc1) checkFunc(false) - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[datapb.DataNodeClient]{ GetGrpcClientErr: nil, } - newFunc2 := func(cc *grpc.ClientConn) interface{} { + newFunc2 := func(cc *grpc.ClientConn) datapb.DataNodeClient { return &mock.GrpcDataNodeClient{Err: errors.New("dummy")} } @@ -116,11 +117,11 @@ func Test_NewClient(t *testing.T) { checkFunc(false) - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[datapb.DataNodeClient]{ GetGrpcClientErr: nil, } - newFunc3 := func(cc *grpc.ClientConn) interface{} { + newFunc3 := func(cc *grpc.ClientConn) datapb.DataNodeClient { return &mock.GrpcDataNodeClient{Err: nil} } client.grpcClient.SetNewGrpcClientFunc(newFunc3) diff --git a/internal/distributed/indexcoord/client/client.go b/internal/distributed/indexcoord/client/client.go index 8c18b009a7..f95f35bd47 100644 --- a/internal/distributed/indexcoord/client/client.go +++ b/internal/distributed/indexcoord/client/client.go @@ -40,7 +40,7 @@ var ClientParams paramtable.GrpcClientConfig // Client is the grpc client of IndexCoord. type Client struct { - grpcClient grpcclient.GrpcClient + grpcClient grpcclient.GrpcClient[indexpb.IndexCoordClient] sess *sessionutil.Session } @@ -54,7 +54,7 @@ func NewClient(ctx context.Context, metaRoot string, etcdCli *clientv3.Client) ( } ClientParams.InitOnce(typeutil.IndexCoordRole) client := &Client{ - grpcClient: &grpcclient.ClientBase{ + grpcClient: &grpcclient.ClientBase[indexpb.IndexCoordClient]{ ClientMaxRecvSize: ClientParams.ClientMaxRecvSize, ClientMaxSendSize: ClientParams.ClientMaxSendSize, DialTimeout: ClientParams.DialTimeout, @@ -112,17 +112,17 @@ func (c *Client) getIndexCoordAddr() (string, error) { } // newGrpcClient create a new grpc client of IndexCoord. -func (c *Client) newGrpcClient(cc *grpc.ClientConn) interface{} { +func (c *Client) newGrpcClient(cc *grpc.ClientConn) indexpb.IndexCoordClient { return indexpb.NewIndexCoordClient(cc) } // GetComponentStates gets the component states of IndexCoord. func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexCoordClient).GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) + return client.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) }) if err != nil || ret == nil { return nil, err @@ -132,11 +132,11 @@ func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentSta // GetStatisticsChannel gets the statistics channel of IndexCoord. func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexCoordClient).GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -146,11 +146,11 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp // CreateIndex sends the build index request to IndexCoord. func (c *Client) CreateIndex(ctx context.Context, req *indexpb.CreateIndexRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexCoordClient).CreateIndex(ctx, req) + return client.CreateIndex(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -160,11 +160,11 @@ func (c *Client) CreateIndex(ctx context.Context, req *indexpb.CreateIndexReques // GetIndexState gets the index states from IndexCoord. func (c *Client) GetIndexState(ctx context.Context, req *indexpb.GetIndexStateRequest) (*indexpb.GetIndexStateResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexCoordClient).GetIndexState(ctx, req) + return client.GetIndexState(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -174,11 +174,11 @@ func (c *Client) GetIndexState(ctx context.Context, req *indexpb.GetIndexStateRe // GetSegmentIndexState gets the index states from IndexCoord. func (c *Client) GetSegmentIndexState(ctx context.Context, req *indexpb.GetSegmentIndexStateRequest) (*indexpb.GetSegmentIndexStateResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexCoordClient).GetSegmentIndexState(ctx, req) + return client.GetSegmentIndexState(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -188,11 +188,11 @@ func (c *Client) GetSegmentIndexState(ctx context.Context, req *indexpb.GetSegme // GetIndexInfos gets the index file paths from IndexCoord. func (c *Client) GetIndexInfos(ctx context.Context, req *indexpb.GetIndexInfoRequest) (*indexpb.GetIndexInfoResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexCoordClient).GetIndexInfos(ctx, req) + return client.GetIndexInfos(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -202,11 +202,11 @@ func (c *Client) GetIndexInfos(ctx context.Context, req *indexpb.GetIndexInfoReq // DescribeIndex describe the index info of the collection. func (c *Client) DescribeIndex(ctx context.Context, req *indexpb.DescribeIndexRequest) (*indexpb.DescribeIndexResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexCoordClient).DescribeIndex(ctx, req) + return client.DescribeIndex(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -216,11 +216,11 @@ func (c *Client) DescribeIndex(ctx context.Context, req *indexpb.DescribeIndexRe // GetIndexBuildProgress describe the progress of the index. func (c *Client) GetIndexBuildProgress(ctx context.Context, req *indexpb.GetIndexBuildProgressRequest) (*indexpb.GetIndexBuildProgressResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexCoordClient).GetIndexBuildProgress(ctx, req) + return client.GetIndexBuildProgress(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -230,11 +230,11 @@ func (c *Client) GetIndexBuildProgress(ctx context.Context, req *indexpb.GetInde // DropIndex sends the drop index request to IndexCoord. func (c *Client) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexCoordClient).DropIndex(ctx, req) + return client.DropIndex(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -244,11 +244,11 @@ func (c *Client) DropIndex(ctx context.Context, req *indexpb.DropIndexRequest) ( // ShowConfigurations gets specified configurations para of IndexCoord func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexCoordClient).ShowConfigurations(ctx, req) + return client.ShowConfigurations(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -259,11 +259,11 @@ func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon // GetMetrics gets the metrics info of IndexCoord. func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexCoordClient).GetMetrics(ctx, req) + return client.GetMetrics(ctx, req) }) if err != nil || ret == nil { return nil, err diff --git a/internal/distributed/indexnode/client/client.go b/internal/distributed/indexnode/client/client.go index 8cce0c6cb6..9601fbdafa 100644 --- a/internal/distributed/indexnode/client/client.go +++ b/internal/distributed/indexnode/client/client.go @@ -36,7 +36,7 @@ var ClientParams paramtable.GrpcClientConfig // Client is the grpc client of IndexNode. type Client struct { - grpcClient grpcclient.GrpcClient + grpcClient grpcclient.GrpcClient[indexpb.IndexNodeClient] addr string } @@ -48,7 +48,7 @@ func NewClient(ctx context.Context, addr string, encryption bool) (*Client, erro ClientParams.InitOnce(typeutil.IndexNodeRole) client := &Client{ addr: addr, - grpcClient: &grpcclient.ClientBase{ + grpcClient: &grpcclient.ClientBase[indexpb.IndexNodeClient]{ ClientMaxRecvSize: ClientParams.ClientMaxRecvSize, ClientMaxSendSize: ClientParams.ClientMaxSendSize, DialTimeout: ClientParams.DialTimeout, @@ -90,7 +90,7 @@ func (c *Client) Register() error { return nil } -func (c *Client) newGrpcClient(cc *grpc.ClientConn) interface{} { +func (c *Client) newGrpcClient(cc *grpc.ClientConn) indexpb.IndexNodeClient { return indexpb.NewIndexNodeClient(cc) } @@ -100,11 +100,11 @@ func (c *Client) getAddr() (string, error) { // GetComponentStates gets the component states of IndexNode. func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexNodeClient).GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) + return client.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) }) if err != nil || ret == nil { return nil, err @@ -113,11 +113,11 @@ func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentSta } func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexNodeClient).GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -127,11 +127,11 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp // CreateJob sends the build index request to IndexNode. func (c *Client) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexNodeClient).CreateJob(ctx, req) + return client.CreateJob(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -141,11 +141,11 @@ func (c *Client) CreateJob(ctx context.Context, req *indexpb.CreateJobRequest) ( // QueryJobs query the task info of the index task. func (c *Client) QueryJobs(ctx context.Context, req *indexpb.QueryJobsRequest) (*indexpb.QueryJobsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexNodeClient).QueryJobs(ctx, req) + return client.QueryJobs(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -155,11 +155,11 @@ func (c *Client) QueryJobs(ctx context.Context, req *indexpb.QueryJobsRequest) ( // DropJobs query the task info of the index task. func (c *Client) DropJobs(ctx context.Context, req *indexpb.DropJobsRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexNodeClient).DropJobs(ctx, req) + return client.DropJobs(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -169,11 +169,11 @@ func (c *Client) DropJobs(ctx context.Context, req *indexpb.DropJobsRequest) (*c // GetJobStats query the task info of the index task. func (c *Client) GetJobStats(ctx context.Context, req *indexpb.GetJobStatsRequest) (*indexpb.GetJobStatsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexNodeClient).GetJobStats(ctx, req) + return client.GetJobStats(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -183,11 +183,11 @@ func (c *Client) GetJobStats(ctx context.Context, req *indexpb.GetJobStatsReques // ShowConfigurations gets specified configurations para of IndexNode func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexNodeClient).ShowConfigurations(ctx, req) + return client.ShowConfigurations(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -198,11 +198,11 @@ func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon // GetMetrics gets the metrics info of IndexNode. func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client indexpb.IndexNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(indexpb.IndexNodeClient).GetMetrics(ctx, req) + return client.GetMetrics(ctx, req) }) if err != nil || ret == nil { return nil, err diff --git a/internal/distributed/indexnode/client/client_test.go b/internal/distributed/indexnode/client/client_test.go index 656cc98b8f..ab7a28d76b 100644 --- a/internal/distributed/indexnode/client/client_test.go +++ b/internal/distributed/indexnode/client/client_test.go @@ -89,32 +89,32 @@ func Test_NewClient(t *testing.T) { retCheck(retNotNil, r7, err) } - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[indexpb.IndexNodeClient]{ GetGrpcClientErr: errors.New("dummy"), } - newFunc1 := func(cc *grpc.ClientConn) interface{} { + newFunc1 := func(cc *grpc.ClientConn) indexpb.IndexNodeClient { return &mock.GrpcIndexNodeClient{Err: nil} } client.grpcClient.SetNewGrpcClientFunc(newFunc1) checkFunc(false) - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[indexpb.IndexNodeClient]{ GetGrpcClientErr: nil, } - newFunc2 := func(cc *grpc.ClientConn) interface{} { + newFunc2 := func(cc *grpc.ClientConn) indexpb.IndexNodeClient { return &mock.GrpcIndexNodeClient{Err: errors.New("dummy")} } client.grpcClient.SetNewGrpcClientFunc(newFunc2) checkFunc(false) - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[indexpb.IndexNodeClient]{ GetGrpcClientErr: nil, } - newFunc3 := func(cc *grpc.ClientConn) interface{} { + newFunc3 := func(cc *grpc.ClientConn) indexpb.IndexNodeClient { return &mock.GrpcIndexNodeClient{Err: nil} } client.grpcClient.SetNewGrpcClientFunc(newFunc3) diff --git a/internal/distributed/proxy/client/client.go b/internal/distributed/proxy/client/client.go index 7e1e65e5fc..51060eb875 100644 --- a/internal/distributed/proxy/client/client.go +++ b/internal/distributed/proxy/client/client.go @@ -35,7 +35,7 @@ var ClientParams paramtable.GrpcClientConfig // Client is the grpc client for Proxy type Client struct { - grpcClient grpcclient.GrpcClient + grpcClient grpcclient.GrpcClient[proxypb.ProxyClient] addr string } @@ -47,7 +47,7 @@ func NewClient(ctx context.Context, addr string) (*Client, error) { ClientParams.InitOnce(typeutil.ProxyRole) client := &Client{ addr: addr, - grpcClient: &grpcclient.ClientBase{ + grpcClient: &grpcclient.ClientBase[proxypb.ProxyClient]{ ClientMaxRecvSize: ClientParams.ClientMaxRecvSize, ClientMaxSendSize: ClientParams.ClientMaxSendSize, DialTimeout: ClientParams.DialTimeout, @@ -71,7 +71,7 @@ func (c *Client) Init() error { return nil } -func (c *Client) newGrpcClient(cc *grpc.ClientConn) interface{} { +func (c *Client) newGrpcClient(cc *grpc.ClientConn) proxypb.ProxyClient { return proxypb.NewProxyClient(cc) } @@ -96,11 +96,11 @@ func (c *Client) Register() error { // GetComponentStates get the component state. func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client proxypb.ProxyClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(proxypb.ProxyClient).GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) + return client.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) }) if err != nil || ret == nil { return nil, err @@ -110,11 +110,11 @@ func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentSta //GetStatisticsChannel return the statistics channel in string func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client proxypb.ProxyClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(proxypb.ProxyClient).GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -124,11 +124,11 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp // InvalidateCollectionMetaCache invalidate collection meta cache func (c *Client) InvalidateCollectionMetaCache(ctx context.Context, req *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client proxypb.ProxyClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(proxypb.ProxyClient).InvalidateCollectionMetaCache(ctx, req) + return client.InvalidateCollectionMetaCache(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -137,11 +137,11 @@ func (c *Client) InvalidateCollectionMetaCache(ctx context.Context, req *proxypb } func (c *Client) InvalidateCredentialCache(ctx context.Context, req *proxypb.InvalidateCredCacheRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client proxypb.ProxyClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(proxypb.ProxyClient).InvalidateCredentialCache(ctx, req) + return client.InvalidateCredentialCache(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -150,11 +150,11 @@ func (c *Client) InvalidateCredentialCache(ctx context.Context, req *proxypb.Inv } func (c *Client) UpdateCredentialCache(ctx context.Context, req *proxypb.UpdateCredCacheRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client proxypb.ProxyClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(proxypb.ProxyClient).UpdateCredentialCache(ctx, req) + return client.UpdateCredentialCache(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -163,11 +163,11 @@ func (c *Client) UpdateCredentialCache(ctx context.Context, req *proxypb.UpdateC } func (c *Client) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.RefreshPolicyInfoCacheRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client proxypb.ProxyClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(proxypb.ProxyClient).RefreshPolicyInfoCache(ctx, req) + return client.RefreshPolicyInfoCache(ctx, req) }) if err != nil { return nil, err @@ -178,11 +178,11 @@ func (c *Client) RefreshPolicyInfoCache(ctx context.Context, req *proxypb.Refres // GetProxyMetrics gets the metrics of proxy, it's an internal interface which is different from GetMetrics interface, // because it only obtains the metrics of Proxy, not including the topological metrics of Query cluster and Data cluster. func (c *Client) GetProxyMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client proxypb.ProxyClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(proxypb.ProxyClient).GetProxyMetrics(ctx, req) + return client.GetProxyMetrics(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -192,11 +192,11 @@ func (c *Client) GetProxyMetrics(ctx context.Context, req *milvuspb.GetMetricsRe // SetRates notifies Proxy to limit rates of requests. func (c *Client) SetRates(ctx context.Context, req *proxypb.SetRatesRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client proxypb.ProxyClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(proxypb.ProxyClient).SetRates(ctx, req) + return client.SetRates(ctx, req) }) if err != nil { return nil, err diff --git a/internal/distributed/proxy/client/client_test.go b/internal/distributed/proxy/client/client_test.go index ea8fb62f4b..3b41b136c2 100644 --- a/internal/distributed/proxy/client/client_test.go +++ b/internal/distributed/proxy/client/client_test.go @@ -22,6 +22,7 @@ import ( "testing" "time" + "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/util/mock" "google.golang.org/grpc" @@ -79,32 +80,32 @@ func Test_NewClient(t *testing.T) { } } - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[proxypb.ProxyClient]{ GetGrpcClientErr: errors.New("dummy"), } - newFunc1 := func(cc *grpc.ClientConn) interface{} { + newFunc1 := func(cc *grpc.ClientConn) proxypb.ProxyClient { return &mock.GrpcProxyClient{Err: nil} } client.grpcClient.SetNewGrpcClientFunc(newFunc1) checkFunc(false) - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[proxypb.ProxyClient]{ GetGrpcClientErr: nil, } - newFunc2 := func(cc *grpc.ClientConn) interface{} { + newFunc2 := func(cc *grpc.ClientConn) proxypb.ProxyClient { return &mock.GrpcProxyClient{Err: errors.New("dummy")} } client.grpcClient.SetNewGrpcClientFunc(newFunc2) checkFunc(false) - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[proxypb.ProxyClient]{ GetGrpcClientErr: nil, } - newFunc3 := func(cc *grpc.ClientConn) interface{} { + newFunc3 := func(cc *grpc.ClientConn) proxypb.ProxyClient { return &mock.GrpcProxyClient{Err: nil} } client.grpcClient.SetNewGrpcClientFunc(newFunc3) diff --git a/internal/distributed/querycoord/client/client.go b/internal/distributed/querycoord/client/client.go index ad150794ec..24caf3e2ca 100644 --- a/internal/distributed/querycoord/client/client.go +++ b/internal/distributed/querycoord/client/client.go @@ -39,7 +39,7 @@ var ClientParams paramtable.GrpcClientConfig // Client is the grpc client of QueryCoord. type Client struct { - grpcClient grpcclient.GrpcClient + grpcClient grpcclient.GrpcClient[querypb.QueryCoordClient] sess *sessionutil.Session } @@ -53,7 +53,7 @@ func NewClient(ctx context.Context, metaRoot string, etcdCli *clientv3.Client) ( } ClientParams.InitOnce(typeutil.QueryCoordRole) client := &Client{ - grpcClient: &grpcclient.ClientBase{ + grpcClient: &grpcclient.ClientBase[querypb.QueryCoordClient]{ ClientMaxRecvSize: ClientParams.ClientMaxRecvSize, ClientMaxSendSize: ClientParams.ClientMaxSendSize, DialTimeout: ClientParams.DialTimeout, @@ -94,7 +94,7 @@ func (c *Client) getQueryCoordAddr() (string, error) { return ms.Address, nil } -func (c *Client) newGrpcClient(cc *grpc.ClientConn) interface{} { +func (c *Client) newGrpcClient(cc *grpc.ClientConn) querypb.QueryCoordClient { return querypb.NewQueryCoordClient(cc) } @@ -115,11 +115,11 @@ func (c *Client) Register() error { // GetComponentStates gets the component states of QueryCoord. func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryCoordClient).GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) + return client.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) }) if err != nil || ret == nil { return nil, err @@ -129,11 +129,11 @@ func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentSta // GetTimeTickChannel gets the time tick channel of QueryCoord. func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryCoordClient).GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) + return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -143,11 +143,11 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon // GetStatisticsChannel gets the statistics channel of QueryCoord. func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryCoordClient).GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -157,11 +157,11 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp // ShowCollections shows the collections in the QueryCoord. func (c *Client) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryCoordClient).ShowCollections(ctx, req) + return client.ShowCollections(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -171,11 +171,11 @@ func (c *Client) ShowCollections(ctx context.Context, req *querypb.ShowCollectio // LoadCollection loads the data of the specified collections in the QueryCoord. func (c *Client) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryCoordClient).LoadCollection(ctx, req) + return client.LoadCollection(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -185,11 +185,11 @@ func (c *Client) LoadCollection(ctx context.Context, req *querypb.LoadCollection // ReleaseCollection release the data of the specified collections in the QueryCoord. func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryCoordClient).ReleaseCollection(ctx, req) + return client.ReleaseCollection(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -199,11 +199,11 @@ func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl // ShowPartitions shows the partitions in the QueryCoord. func (c *Client) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryCoordClient).ShowPartitions(ctx, req) + return client.ShowPartitions(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -213,11 +213,11 @@ func (c *Client) ShowPartitions(ctx context.Context, req *querypb.ShowPartitions // LoadPartitions loads the data of the specified partitions in the QueryCoord. func (c *Client) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryCoordClient).LoadPartitions(ctx, req) + return client.LoadPartitions(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -227,11 +227,11 @@ func (c *Client) LoadPartitions(ctx context.Context, req *querypb.LoadPartitions // ReleasePartitions release the data of the specified partitions in the QueryCoord. func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryCoordClient).ReleasePartitions(ctx, req) + return client.ReleasePartitions(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -241,11 +241,11 @@ func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart // GetPartitionStates gets the states of the specified partition. func (c *Client) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryCoordClient).GetPartitionStates(ctx, req) + return client.GetPartitionStates(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -255,11 +255,11 @@ func (c *Client) GetPartitionStates(ctx context.Context, req *querypb.GetPartiti // GetSegmentInfo gets the information of the specified segment from QueryCoord. func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryCoordClient).GetSegmentInfo(ctx, req) + return client.GetSegmentInfo(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -269,11 +269,11 @@ func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo // LoadBalance migrate the sealed segments on the source node to the dst nodes. func (c *Client) LoadBalance(ctx context.Context, req *querypb.LoadBalanceRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryCoordClient).LoadBalance(ctx, req) + return client.LoadBalance(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -283,11 +283,11 @@ func (c *Client) LoadBalance(ctx context.Context, req *querypb.LoadBalanceReques // ShowConfigurations gets specified configurations para of QueryCoord func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryCoordClient).ShowConfigurations(ctx, req) + return client.ShowConfigurations(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -298,11 +298,11 @@ func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon // GetMetrics gets the metrics information of QueryCoord. func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryCoordClient).GetMetrics(ctx, req) + return client.GetMetrics(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -312,11 +312,11 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest // GetReplicas gets the replicas of a certain collection. func (c *Client) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasRequest) (*milvuspb.GetReplicasResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryCoordClient).GetReplicas(ctx, req) + return client.GetReplicas(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -326,11 +326,11 @@ func (c *Client) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasReque // GetShardLeaders gets the shard leaders of a certain collection. func (c *Client) GetShardLeaders(ctx context.Context, req *querypb.GetShardLeadersRequest) (*querypb.GetShardLeadersResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryCoordClient).GetShardLeaders(ctx, req) + return client.GetShardLeaders(ctx, req) }) if err != nil || ret == nil { return nil, err diff --git a/internal/distributed/querycoord/client/client_test.go b/internal/distributed/querycoord/client/client_test.go index 8a77255e55..d47864837c 100644 --- a/internal/distributed/querycoord/client/client_test.go +++ b/internal/distributed/querycoord/client/client_test.go @@ -21,6 +21,7 @@ import ( "errors" "testing" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/util/mock" "google.golang.org/grpc" @@ -50,7 +51,7 @@ func Test_NewClient(t *testing.T) { assert.Nil(t, err) checkFunc := func(retNotNil bool) { - retCheck := func(notNil bool, ret interface{}, err error) { + retCheck := func(notNil bool, ret any, err error) { if notNil { assert.NotNil(t, ret) assert.Nil(t, err) @@ -115,22 +116,22 @@ func Test_NewClient(t *testing.T) { retCheck(retNotNil, r19, err) } - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[querypb.QueryCoordClient]{ GetGrpcClientErr: errors.New("dummy"), } - newFunc1 := func(cc *grpc.ClientConn) interface{} { + newFunc1 := func(cc *grpc.ClientConn) querypb.QueryCoordClient { return &mock.GrpcQueryCoordClient{Err: nil} } client.grpcClient.SetNewGrpcClientFunc(newFunc1) checkFunc(false) - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[querypb.QueryCoordClient]{ GetGrpcClientErr: nil, } - newFunc2 := func(cc *grpc.ClientConn) interface{} { + newFunc2 := func(cc *grpc.ClientConn) querypb.QueryCoordClient { return &mock.GrpcQueryCoordClient{Err: errors.New("dummy")} } @@ -138,11 +139,11 @@ func Test_NewClient(t *testing.T) { checkFunc(false) - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[querypb.QueryCoordClient]{ GetGrpcClientErr: nil, } - newFunc3 := func(cc *grpc.ClientConn) interface{} { + newFunc3 := func(cc *grpc.ClientConn) querypb.QueryCoordClient { return &mock.GrpcQueryCoordClient{Err: nil} } client.grpcClient.SetNewGrpcClientFunc(newFunc3) diff --git a/internal/distributed/querynode/client/client.go b/internal/distributed/querynode/client/client.go index 17a14aafdc..13610c6cd3 100644 --- a/internal/distributed/querynode/client/client.go +++ b/internal/distributed/querynode/client/client.go @@ -36,7 +36,7 @@ var ClientParams paramtable.GrpcClientConfig // Client is the grpc client of QueryNode. type Client struct { - grpcClient grpcclient.GrpcClient + grpcClient grpcclient.GrpcClient[querypb.QueryNodeClient] addr string } @@ -48,7 +48,7 @@ func NewClient(ctx context.Context, addr string) (*Client, error) { ClientParams.InitOnce(typeutil.QueryNodeRole) client := &Client{ addr: addr, - grpcClient: &grpcclient.ClientBase{ + grpcClient: &grpcclient.ClientBase[querypb.QueryNodeClient]{ ClientMaxRecvSize: ClientParams.ClientMaxRecvSize, ClientMaxSendSize: ClientParams.ClientMaxSendSize, DialTimeout: ClientParams.DialTimeout, @@ -88,7 +88,7 @@ func (c *Client) Register() error { return nil } -func (c *Client) newGrpcClient(cc *grpc.ClientConn) interface{} { +func (c *Client) newGrpcClient(cc *grpc.ClientConn) querypb.QueryNodeClient { return querypb.NewQueryNodeClient(cc) } @@ -98,11 +98,11 @@ func (c *Client) getAddr() (string, error) { // GetComponentStates gets the component states of QueryNode. func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryNodeClient).GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) + return client.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) }) if err != nil || ret == nil { return nil, err @@ -112,11 +112,11 @@ func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentSta // GetTimeTickChannel gets the time tick channel of QueryNode. func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryNodeClient).GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) + return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -126,11 +126,11 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon // GetStatisticsChannel gets the statistics channel of QueryNode. func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryNodeClient).GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -140,11 +140,11 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp // WatchDmChannels watches the channels about data manipulation. func (c *Client) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryNodeClient).WatchDmChannels(ctx, req) + return client.WatchDmChannels(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -154,11 +154,11 @@ func (c *Client) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChanne // UnsubDmChannel unsubscribes the channels about data manipulation. func (c *Client) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannelRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryNodeClient).UnsubDmChannel(ctx, req) + return client.UnsubDmChannel(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -168,11 +168,11 @@ func (c *Client) UnsubDmChannel(ctx context.Context, req *querypb.UnsubDmChannel // LoadSegments loads the segments to search. func (c *Client) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryNodeClient).LoadSegments(ctx, req) + return client.LoadSegments(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -182,11 +182,11 @@ func (c *Client) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequ // ReleaseCollection releases the data of the specified collection in QueryNode. func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryNodeClient).ReleaseCollection(ctx, req) + return client.ReleaseCollection(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -196,11 +196,11 @@ func (c *Client) ReleaseCollection(ctx context.Context, req *querypb.ReleaseColl // ReleasePartitions releases the data of the specified partitions in QueryNode. func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryNodeClient).ReleasePartitions(ctx, req) + return client.ReleasePartitions(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -210,11 +210,11 @@ func (c *Client) ReleasePartitions(ctx context.Context, req *querypb.ReleasePart // ReleaseSegments releases the data of the specified segments in QueryNode. func (c *Client) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryNodeClient).ReleaseSegments(ctx, req) + return client.ReleaseSegments(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -224,11 +224,11 @@ func (c *Client) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmen // Search performs replica search tasks in QueryNode. func (c *Client) Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) { - ret, err := c.grpcClient.Call(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.Call(ctx, func(client querypb.QueryNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryNodeClient).Search(ctx, req) + return client.Search(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -238,11 +238,11 @@ func (c *Client) Search(ctx context.Context, req *querypb.SearchRequest) (*inter // Query performs replica query tasks in QueryNode. func (c *Client) Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) { - ret, err := c.grpcClient.Call(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.Call(ctx, func(client querypb.QueryNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryNodeClient).Query(ctx, req) + return client.Query(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -252,11 +252,11 @@ func (c *Client) Query(ctx context.Context, req *querypb.QueryRequest) (*interna // GetSegmentInfo gets the information of the specified segments in QueryNode. func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryNodeClient).GetSegmentInfo(ctx, req) + return client.GetSegmentInfo(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -266,11 +266,11 @@ func (c *Client) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfo // SyncReplicaSegments syncs replica node segments information to shard leaders. func (c *Client) SyncReplicaSegments(ctx context.Context, req *querypb.SyncReplicaSegmentsRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryNodeClient).SyncReplicaSegments(ctx, req) + return client.SyncReplicaSegments(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -280,11 +280,11 @@ func (c *Client) SyncReplicaSegments(ctx context.Context, req *querypb.SyncRepli // ShowConfigurations gets specified configurations para of QueryNode func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryNodeClient).ShowConfigurations(ctx, req) + return client.ShowConfigurations(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -295,11 +295,11 @@ func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon // GetMetrics gets the metrics information of QueryNode. func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client querypb.QueryNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryNodeClient).GetMetrics(ctx, req) + return client.GetMetrics(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -308,11 +308,11 @@ func (c *Client) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest } func (c *Client) GetStatistics(ctx context.Context, request *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) { - ret, err := c.grpcClient.Call(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.Call(ctx, func(client querypb.QueryNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryNodeClient).GetStatistics(ctx, request) + return client.GetStatistics(ctx, request) }) if err != nil || ret == nil { return nil, err @@ -321,11 +321,11 @@ func (c *Client) GetStatistics(ctx context.Context, request *querypb.GetStatisti } func (c *Client) GetDataDistribution(ctx context.Context, req *querypb.GetDataDistributionRequest) (*querypb.GetDataDistributionResponse, error) { - ret, err := c.grpcClient.Call(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.Call(ctx, func(client querypb.QueryNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryNodeClient).GetDataDistribution(ctx, req) + return client.GetDataDistribution(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -334,11 +334,11 @@ func (c *Client) GetDataDistribution(ctx context.Context, req *querypb.GetDataDi } func (c *Client) SyncDistribution(ctx context.Context, req *querypb.SyncDistributionRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.Call(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.Call(ctx, func(client querypb.QueryNodeClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(querypb.QueryNodeClient).SyncDistribution(ctx, req) + return client.SyncDistribution(ctx, req) }) if err != nil || ret == nil { return nil, err diff --git a/internal/distributed/querynode/client/client_test.go b/internal/distributed/querynode/client/client_test.go index 8c3c03f9e0..8361060084 100644 --- a/internal/distributed/querynode/client/client_test.go +++ b/internal/distributed/querynode/client/client_test.go @@ -21,6 +21,7 @@ import ( "errors" "testing" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/util/mock" "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/stretchr/testify/assert" @@ -50,7 +51,7 @@ func Test_NewClient(t *testing.T) { ctx, cancel := context.WithCancel(ctx) checkFunc := func(retNotNil bool) { - retCheck := func(notNil bool, ret interface{}, err error) { + retCheck := func(notNil bool, ret any, err error) { if notNil { assert.NotNil(t, ret) assert.Nil(t, err) @@ -106,22 +107,22 @@ func Test_NewClient(t *testing.T) { retCheck(retNotNil, r18, err) } - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[querypb.QueryNodeClient]{ GetGrpcClientErr: errors.New("dummy"), } - newFunc1 := func(cc *grpc.ClientConn) interface{} { + newFunc1 := func(cc *grpc.ClientConn) querypb.QueryNodeClient { return &mock.GrpcQueryNodeClient{Err: nil} } client.grpcClient.SetNewGrpcClientFunc(newFunc1) checkFunc(false) - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[querypb.QueryNodeClient]{ GetGrpcClientErr: nil, } - newFunc2 := func(cc *grpc.ClientConn) interface{} { + newFunc2 := func(cc *grpc.ClientConn) querypb.QueryNodeClient { return &mock.GrpcQueryNodeClient{Err: errors.New("dummy")} } @@ -129,11 +130,11 @@ func Test_NewClient(t *testing.T) { checkFunc(false) - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[querypb.QueryNodeClient]{ GetGrpcClientErr: nil, } - newFunc3 := func(cc *grpc.ClientConn) interface{} { + newFunc3 := func(cc *grpc.ClientConn) querypb.QueryNodeClient { return &mock.GrpcQueryNodeClient{Err: nil} } client.grpcClient.SetNewGrpcClientFunc(newFunc3) @@ -141,7 +142,7 @@ func Test_NewClient(t *testing.T) { checkFunc(true) // ctx canceled - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[querypb.QueryNodeClient]{ GetGrpcClientErr: nil, } client.grpcClient.SetNewGrpcClientFunc(newFunc1) diff --git a/internal/distributed/rootcoord/client/client.go b/internal/distributed/rootcoord/client/client.go index 58b291f7b9..807af1b726 100644 --- a/internal/distributed/rootcoord/client/client.go +++ b/internal/distributed/rootcoord/client/client.go @@ -40,7 +40,7 @@ var ClientParams paramtable.GrpcClientConfig // Client grpc client type Client struct { - grpcClient grpcclient.GrpcClient + grpcClient grpcclient.GrpcClient[rootcoordpb.RootCoordClient] sess *sessionutil.Session } @@ -58,7 +58,7 @@ func NewClient(ctx context.Context, metaRoot string, etcdCli *clientv3.Client) ( } ClientParams.InitOnce(typeutil.RootCoordRole) client := &Client{ - grpcClient: &grpcclient.ClientBase{ + grpcClient: &grpcclient.ClientBase[rootcoordpb.RootCoordClient]{ ClientMaxRecvSize: ClientParams.ClientMaxRecvSize, ClientMaxSendSize: ClientParams.ClientMaxSendSize, DialTimeout: ClientParams.DialTimeout, @@ -84,7 +84,7 @@ func (c *Client) Init() error { return nil } -func (c *Client) newGrpcClient(cc *grpc.ClientConn) interface{} { +func (c *Client) newGrpcClient(cc *grpc.ClientConn) rootcoordpb.RootCoordClient { return rootcoordpb.NewRootCoordClient(cc) } @@ -121,11 +121,11 @@ func (c *Client) Register() error { // GetComponentStates TODO: timeout need to be propagated through ctx func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentStates, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) + return client.GetComponentStates(ctx, &milvuspb.GetComponentStatesRequest{}) }) if err != nil || ret == nil { return nil, err @@ -135,11 +135,11 @@ func (c *Client) GetComponentStates(ctx context.Context) (*milvuspb.ComponentSta // GetTimeTickChannel get timetick channel name func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) + return client.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -149,11 +149,11 @@ func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringRespon // GetStatisticsChannel just define a channel, not used currently func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + return client.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) }) if err != nil || ret == nil { return nil, err @@ -163,11 +163,11 @@ func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResp // CreateCollection create collection func (c *Client) CreateCollection(ctx context.Context, in *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).CreateCollection(ctx, in) + return client.CreateCollection(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -177,11 +177,11 @@ func (c *Client) CreateCollection(ctx context.Context, in *milvuspb.CreateCollec // DropCollection drop collection func (c *Client) DropCollection(ctx context.Context, in *milvuspb.DropCollectionRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).DropCollection(ctx, in) + return client.DropCollection(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -191,11 +191,11 @@ func (c *Client) DropCollection(ctx context.Context, in *milvuspb.DropCollection // HasCollection check collection existence func (c *Client) HasCollection(ctx context.Context, in *milvuspb.HasCollectionRequest) (*milvuspb.BoolResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).HasCollection(ctx, in) + return client.HasCollection(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -205,11 +205,11 @@ func (c *Client) HasCollection(ctx context.Context, in *milvuspb.HasCollectionRe // DescribeCollection return collection info func (c *Client) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).DescribeCollection(ctx, in) + return client.DescribeCollection(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -219,11 +219,11 @@ func (c *Client) DescribeCollection(ctx context.Context, in *milvuspb.DescribeCo // ShowCollections list all collection names func (c *Client) ShowCollections(ctx context.Context, in *milvuspb.ShowCollectionsRequest) (*milvuspb.ShowCollectionsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).ShowCollections(ctx, in) + return client.ShowCollections(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -232,11 +232,11 @@ func (c *Client) ShowCollections(ctx context.Context, in *milvuspb.ShowCollectio } func (c *Client) AlterCollection(ctx context.Context, request *milvuspb.AlterCollectionRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).AlterCollection(ctx, request) + return client.AlterCollection(ctx, request) }) if err != nil || ret == nil { return nil, err @@ -246,11 +246,11 @@ func (c *Client) AlterCollection(ctx context.Context, request *milvuspb.AlterCol // CreatePartition create partition func (c *Client) CreatePartition(ctx context.Context, in *milvuspb.CreatePartitionRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).CreatePartition(ctx, in) + return client.CreatePartition(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -260,11 +260,11 @@ func (c *Client) CreatePartition(ctx context.Context, in *milvuspb.CreatePartiti // DropPartition drop partition func (c *Client) DropPartition(ctx context.Context, in *milvuspb.DropPartitionRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).DropPartition(ctx, in) + return client.DropPartition(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -274,11 +274,11 @@ func (c *Client) DropPartition(ctx context.Context, in *milvuspb.DropPartitionRe // HasPartition check partition existence func (c *Client) HasPartition(ctx context.Context, in *milvuspb.HasPartitionRequest) (*milvuspb.BoolResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).HasPartition(ctx, in) + return client.HasPartition(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -288,11 +288,11 @@ func (c *Client) HasPartition(ctx context.Context, in *milvuspb.HasPartitionRequ // ShowPartitions list all partitions in collection func (c *Client) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).ShowPartitions(ctx, in) + return client.ShowPartitions(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -302,11 +302,11 @@ func (c *Client) ShowPartitions(ctx context.Context, in *milvuspb.ShowPartitions // AllocTimestamp global timestamp allocator func (c *Client) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimestampRequest) (*rootcoordpb.AllocTimestampResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).AllocTimestamp(ctx, in) + return client.AllocTimestamp(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -316,11 +316,11 @@ func (c *Client) AllocTimestamp(ctx context.Context, in *rootcoordpb.AllocTimest // AllocID global ID allocator func (c *Client) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest) (*rootcoordpb.AllocIDResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).AllocID(ctx, in) + return client.AllocID(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -330,11 +330,11 @@ func (c *Client) AllocID(ctx context.Context, in *rootcoordpb.AllocIDRequest) (* // UpdateChannelTimeTick used to handle ChannelTimeTickMsg func (c *Client) UpdateChannelTimeTick(ctx context.Context, in *internalpb.ChannelTimeTickMsg) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).UpdateChannelTimeTick(ctx, in) + return client.UpdateChannelTimeTick(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -344,11 +344,11 @@ func (c *Client) UpdateChannelTimeTick(ctx context.Context, in *internalpb.Chann // ShowSegments list all segments func (c *Client) ShowSegments(ctx context.Context, in *milvuspb.ShowSegmentsRequest) (*milvuspb.ShowSegmentsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).ShowSegments(ctx, in) + return client.ShowSegments(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -358,11 +358,11 @@ func (c *Client) ShowSegments(ctx context.Context, in *milvuspb.ShowSegmentsRequ // InvalidateCollectionMetaCache notifies RootCoord to release the collection cache in Proxies. func (c *Client) InvalidateCollectionMetaCache(ctx context.Context, in *proxypb.InvalidateCollMetaCacheRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).InvalidateCollectionMetaCache(ctx, in) + return client.InvalidateCollectionMetaCache(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -372,11 +372,11 @@ func (c *Client) InvalidateCollectionMetaCache(ctx context.Context, in *proxypb. // ShowConfigurations gets specified configurations para of RootCoord func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowConfigurationsRequest) (*internalpb.ShowConfigurationsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).ShowConfigurations(ctx, req) + return client.ShowConfigurations(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -387,11 +387,11 @@ func (c *Client) ShowConfigurations(ctx context.Context, req *internalpb.ShowCon // GetMetrics get metrics func (c *Client) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).GetMetrics(ctx, in) + return client.GetMetrics(ctx, in) }) if err != nil || ret == nil { return nil, err @@ -401,11 +401,11 @@ func (c *Client) GetMetrics(ctx context.Context, in *milvuspb.GetMetricsRequest) // CreateAlias create collection alias func (c *Client) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).CreateAlias(ctx, req) + return client.CreateAlias(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -415,11 +415,11 @@ func (c *Client) CreateAlias(ctx context.Context, req *milvuspb.CreateAliasReque // DropAlias drop collection alias func (c *Client) DropAlias(ctx context.Context, req *milvuspb.DropAliasRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).DropAlias(ctx, req) + return client.DropAlias(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -429,11 +429,11 @@ func (c *Client) DropAlias(ctx context.Context, req *milvuspb.DropAliasRequest) // AlterAlias alter collection alias func (c *Client) AlterAlias(ctx context.Context, req *milvuspb.AlterAliasRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).AlterAlias(ctx, req) + return client.AlterAlias(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -443,11 +443,11 @@ func (c *Client) AlterAlias(ctx context.Context, req *milvuspb.AlterAliasRequest // Import data files(json, numpy, etc.) on MinIO/S3 storage, read and parse them into sealed segments func (c *Client) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milvuspb.ImportResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).Import(ctx, req) + return client.Import(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -457,11 +457,11 @@ func (c *Client) Import(ctx context.Context, req *milvuspb.ImportRequest) (*milv // Check import task state from datanode func (c *Client) GetImportState(ctx context.Context, req *milvuspb.GetImportStateRequest) (*milvuspb.GetImportStateResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).GetImportState(ctx, req) + return client.GetImportState(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -471,11 +471,11 @@ func (c *Client) GetImportState(ctx context.Context, req *milvuspb.GetImportStat // List id array of all import tasks func (c *Client) ListImportTasks(ctx context.Context, req *milvuspb.ListImportTasksRequest) (*milvuspb.ListImportTasksResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).ListImportTasks(ctx, req) + return client.ListImportTasks(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -485,11 +485,11 @@ func (c *Client) ListImportTasks(ctx context.Context, req *milvuspb.ListImportTa // Report impot task state to rootcoord func (c *Client) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).ReportImport(ctx, req) + return client.ReportImport(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -498,11 +498,11 @@ func (c *Client) ReportImport(ctx context.Context, req *rootcoordpb.ImportResult } func (c *Client) CreateCredential(ctx context.Context, req *internalpb.CredentialInfo) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).CreateCredential(ctx, req) + return client.CreateCredential(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -511,11 +511,11 @@ func (c *Client) CreateCredential(ctx context.Context, req *internalpb.Credentia } func (c *Client) GetCredential(ctx context.Context, req *rootcoordpb.GetCredentialRequest) (*rootcoordpb.GetCredentialResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).GetCredential(ctx, req) + return client.GetCredential(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -524,11 +524,11 @@ func (c *Client) GetCredential(ctx context.Context, req *rootcoordpb.GetCredenti } func (c *Client) UpdateCredential(ctx context.Context, req *internalpb.CredentialInfo) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).UpdateCredential(ctx, req) + return client.UpdateCredential(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -537,11 +537,11 @@ func (c *Client) UpdateCredential(ctx context.Context, req *internalpb.Credentia } func (c *Client) DeleteCredential(ctx context.Context, req *milvuspb.DeleteCredentialRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).DeleteCredential(ctx, req) + return client.DeleteCredential(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -550,11 +550,11 @@ func (c *Client) DeleteCredential(ctx context.Context, req *milvuspb.DeleteCrede } func (c *Client) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUsersRequest) (*milvuspb.ListCredUsersResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).ListCredUsers(ctx, req) + return client.ListCredUsers(ctx, req) }) if err != nil || ret == nil { return nil, err @@ -563,11 +563,11 @@ func (c *Client) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUsersR } func (c *Client) CreateRole(ctx context.Context, req *milvuspb.CreateRoleRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).CreateRole(ctx, req) + return client.CreateRole(ctx, req) }) if err != nil { return nil, err @@ -576,11 +576,11 @@ func (c *Client) CreateRole(ctx context.Context, req *milvuspb.CreateRoleRequest } func (c *Client) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).DropRole(ctx, req) + return client.DropRole(ctx, req) }) if err != nil { return nil, err @@ -589,11 +589,11 @@ func (c *Client) DropRole(ctx context.Context, req *milvuspb.DropRoleRequest) (* } func (c *Client) OperateUserRole(ctx context.Context, req *milvuspb.OperateUserRoleRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).OperateUserRole(ctx, req) + return client.OperateUserRole(ctx, req) }) if err != nil { return nil, err @@ -602,11 +602,11 @@ func (c *Client) OperateUserRole(ctx context.Context, req *milvuspb.OperateUserR } func (c *Client) SelectRole(ctx context.Context, req *milvuspb.SelectRoleRequest) (*milvuspb.SelectRoleResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).SelectRole(ctx, req) + return client.SelectRole(ctx, req) }) if err != nil { return nil, err @@ -615,11 +615,11 @@ func (c *Client) SelectRole(ctx context.Context, req *milvuspb.SelectRoleRequest } func (c *Client) SelectUser(ctx context.Context, req *milvuspb.SelectUserRequest) (*milvuspb.SelectUserResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).SelectUser(ctx, req) + return client.SelectUser(ctx, req) }) if err != nil { return nil, err @@ -628,11 +628,11 @@ func (c *Client) SelectUser(ctx context.Context, req *milvuspb.SelectUserRequest } func (c *Client) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePrivilegeRequest) (*commonpb.Status, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).OperatePrivilege(ctx, req) + return client.OperatePrivilege(ctx, req) }) if err != nil { return nil, err @@ -641,11 +641,11 @@ func (c *Client) OperatePrivilege(ctx context.Context, req *milvuspb.OperatePriv } func (c *Client) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantRequest) (*milvuspb.SelectGrantResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).SelectGrant(ctx, req) + return client.SelectGrant(ctx, req) }) if err != nil { return nil, err @@ -654,11 +654,11 @@ func (c *Client) SelectGrant(ctx context.Context, req *milvuspb.SelectGrantReque } func (c *Client) ListPolicy(ctx context.Context, req *internalpb.ListPolicyRequest) (*internalpb.ListPolicyResponse, error) { - ret, err := c.grpcClient.ReCall(ctx, func(client interface{}) (interface{}, error) { + ret, err := c.grpcClient.ReCall(ctx, func(client rootcoordpb.RootCoordClient) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } - return client.(rootcoordpb.RootCoordClient).ListPolicy(ctx, req) + return client.ListPolicy(ctx, req) }) if err != nil { return nil, err diff --git a/internal/distributed/rootcoord/client/client_test.go b/internal/distributed/rootcoord/client/client_test.go index 97271b933e..6176f669a0 100644 --- a/internal/distributed/rootcoord/client/client_test.go +++ b/internal/distributed/rootcoord/client/client_test.go @@ -22,6 +22,7 @@ import ( "testing" "time" + "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/util/mock" "google.golang.org/grpc" @@ -214,22 +215,22 @@ func Test_NewClient(t *testing.T) { } } - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[rootcoordpb.RootCoordClient]{ GetGrpcClientErr: errors.New("dummy"), } - newFunc1 := func(cc *grpc.ClientConn) interface{} { + newFunc1 := func(cc *grpc.ClientConn) rootcoordpb.RootCoordClient { return &mock.GrpcRootCoordClient{Err: nil} } client.grpcClient.SetNewGrpcClientFunc(newFunc1) checkFunc(false) - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[rootcoordpb.RootCoordClient]{ GetGrpcClientErr: nil, } - newFunc2 := func(cc *grpc.ClientConn) interface{} { + newFunc2 := func(cc *grpc.ClientConn) rootcoordpb.RootCoordClient { return &mock.GrpcRootCoordClient{Err: errors.New("dummy")} } @@ -237,11 +238,11 @@ func Test_NewClient(t *testing.T) { checkFunc(false) - client.grpcClient = &mock.GRPCClientBase{ + client.grpcClient = &mock.GRPCClientBase[rootcoordpb.RootCoordClient]{ GetGrpcClientErr: nil, } - newFunc3 := func(cc *grpc.ClientConn) interface{} { + newFunc3 := func(cc *grpc.ClientConn) rootcoordpb.RootCoordClient { return &mock.GrpcRootCoordClient{Err: nil} } client.grpcClient.SetNewGrpcClientFunc(newFunc3) diff --git a/internal/util/generic/generic.go b/internal/util/generic/generic.go new file mode 100644 index 0000000000..cbc51eee17 --- /dev/null +++ b/internal/util/generic/generic.go @@ -0,0 +1,31 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package generic + +import "reflect" + +func Zero[T any]() T { + return *new(T) +} + +func IsZero[T any](v T) bool { + return reflect.ValueOf(&v).Elem().IsZero() +} + +func Equal(v1, v2 any) bool { + return v1 == v2 +} diff --git a/internal/util/grpcclient/client.go b/internal/util/grpcclient/client.go index 80b00df0ab..4f6bad695c 100644 --- a/internal/util/grpcclient/client.go +++ b/internal/util/grpcclient/client.go @@ -34,28 +34,29 @@ import ( "github.com/milvus-io/milvus/internal/util" "github.com/milvus-io/milvus/internal/util/crypto" "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/generic" "github.com/milvus-io/milvus/internal/util/trace" ) // GrpcClient abstracts client of grpc -type GrpcClient interface { +type GrpcClient[T any] interface { SetRole(string) GetRole() string SetGetAddrFunc(func() (string, error)) EnableEncryption() - SetNewGrpcClientFunc(func(cc *grpc.ClientConn) interface{}) - GetGrpcClient(ctx context.Context) (interface{}, error) - ReCall(ctx context.Context, caller func(client interface{}) (interface{}, error)) (interface{}, error) - Call(ctx context.Context, caller func(client interface{}) (interface{}, error)) (interface{}, error) + SetNewGrpcClientFunc(func(cc *grpc.ClientConn) T) + GetGrpcClient(ctx context.Context) (T, error) + ReCall(ctx context.Context, caller func(client T) (any, error)) (any, error) + Call(ctx context.Context, caller func(client T) (any, error)) (any, error) Close() error } // ClientBase is a base of grpc client -type ClientBase struct { +type ClientBase[T any] struct { getAddrFunc func() (string, error) - newGrpcClient func(cc *grpc.ClientConn) interface{} + newGrpcClient func(cc *grpc.ClientConn) T - grpcClient interface{} + grpcClient T encryption bool conn *grpc.ClientConn grpcClientMtx sync.RWMutex @@ -75,34 +76,34 @@ type ClientBase struct { } // SetRole sets role of client -func (c *ClientBase) SetRole(role string) { +func (c *ClientBase[T]) SetRole(role string) { c.role = role } // GetRole returns role of client -func (c *ClientBase) GetRole() string { +func (c *ClientBase[T]) GetRole() string { return c.role } // SetGetAddrFunc sets getAddrFunc of client -func (c *ClientBase) SetGetAddrFunc(f func() (string, error)) { +func (c *ClientBase[T]) SetGetAddrFunc(f func() (string, error)) { c.getAddrFunc = f } -func (c *ClientBase) EnableEncryption() { +func (c *ClientBase[T]) EnableEncryption() { c.encryption = true } // SetNewGrpcClientFunc sets newGrpcClient of client -func (c *ClientBase) SetNewGrpcClientFunc(f func(cc *grpc.ClientConn) interface{}) { +func (c *ClientBase[T]) SetNewGrpcClientFunc(f func(cc *grpc.ClientConn) T) { c.newGrpcClient = f } // GetGrpcClient returns grpc client -func (c *ClientBase) GetGrpcClient(ctx context.Context) (interface{}, error) { +func (c *ClientBase[T]) GetGrpcClient(ctx context.Context) (T, error) { c.grpcClientMtx.RLock() - if c.grpcClient != nil { + if !generic.IsZero(c.grpcClient) { defer c.grpcClientMtx.RUnlock() return c.grpcClient, nil } @@ -111,36 +112,35 @@ func (c *ClientBase) GetGrpcClient(ctx context.Context) (interface{}, error) { c.grpcClientMtx.Lock() defer c.grpcClientMtx.Unlock() - if c.grpcClient != nil { + if !generic.IsZero(c.grpcClient) { return c.grpcClient, nil } err := c.connect(ctx) if err != nil { - return nil, err + return generic.Zero[T](), err } return c.grpcClient, nil } -func (c *ClientBase) resetConnection(client interface{}) { +func (c *ClientBase[T]) resetConnection(client T) { c.grpcClientMtx.Lock() defer c.grpcClientMtx.Unlock() - if c.grpcClient == nil { + if generic.IsZero(c.grpcClient) { return } - - if client != c.grpcClient { + if !generic.Equal(client, c.grpcClient) { return } if c.conn != nil { _ = c.conn.Close() } c.conn = nil - c.grpcClient = nil + c.grpcClient = generic.Zero[T]() } -func (c *ClientBase) connect(ctx context.Context) error { +func (c *ClientBase[T]) connect(ctx context.Context) error { addr, err := c.getAddrFunc() if err != nil { log.Error("failed to get client address", zap.Error(err)) @@ -240,10 +240,10 @@ func (c *ClientBase) connect(ctx context.Context) error { return nil } -func (c *ClientBase) callOnce(ctx context.Context, caller func(client interface{}) (interface{}, error)) (interface{}, error) { +func (c *ClientBase[T]) callOnce(ctx context.Context, caller func(client T) (any, error)) (any, error) { client, err := c.GetGrpcClient(ctx) if err != nil { - return nil, err + return generic.Zero[T](), err } ret, err2 := caller(client) @@ -252,11 +252,11 @@ func (c *ClientBase) callOnce(ctx context.Context, caller func(client interface{ } if !funcutil.CheckCtxValid(ctx) { - return nil, err2 + return generic.Zero[T](), err2 } if !funcutil.IsGrpcErr(err2) { log.Debug("ClientBase:isNotGrpcErr", zap.Error(err2)) - return nil, err2 + return generic.Zero[T](), err2 } log.Debug(c.GetRole()+" ClientBase grpc error, start to reset connection", zap.Error(err2)) c.resetConnection(client) @@ -264,24 +264,24 @@ func (c *ClientBase) callOnce(ctx context.Context, caller func(client interface{ } // Call does a grpc call -func (c *ClientBase) Call(ctx context.Context, caller func(client interface{}) (interface{}, error)) (interface{}, error) { +func (c *ClientBase[T]) Call(ctx context.Context, caller func(client T) (any, error)) (any, error) { if !funcutil.CheckCtxValid(ctx) { - return nil, ctx.Err() + return generic.Zero[T](), ctx.Err() } ret, err := c.callOnce(ctx, caller) if err != nil { traceErr := fmt.Errorf("err: %w\n, %s", err, trace.StackTrace()) log.Error("ClientBase Call grpc first call get error", zap.String("role", c.GetRole()), zap.Error(traceErr)) - return nil, traceErr + return generic.Zero[T](), traceErr } return ret, err } // ReCall does the grpc call twice -func (c *ClientBase) ReCall(ctx context.Context, caller func(client interface{}) (interface{}, error)) (interface{}, error) { +func (c *ClientBase[T]) ReCall(ctx context.Context, caller func(client T) (any, error)) (any, error) { if !funcutil.CheckCtxValid(ctx) { - return nil, ctx.Err() + return generic.Zero[T](), ctx.Err() } ret, err := c.callOnce(ctx, caller) @@ -293,20 +293,20 @@ func (c *ClientBase) ReCall(ctx context.Context, caller func(client interface{}) log.Warn(c.GetRole()+" ClientBase ReCall grpc first call get error ", zap.Error(traceErr)) if !funcutil.CheckCtxValid(ctx) { - return nil, ctx.Err() + return generic.Zero[T](), ctx.Err() } ret, err = c.callOnce(ctx, caller) if err != nil { traceErr = fmt.Errorf("err: %w\n, %s", err, trace.StackTrace()) log.Error("ClientBase ReCall grpc second call get error", zap.String("role", c.GetRole()), zap.Error(traceErr)) - return nil, traceErr + return generic.Zero[T](), traceErr } return ret, err } // Close close the client connection -func (c *ClientBase) Close() error { +func (c *ClientBase[T]) Close() error { c.grpcClientMtx.Lock() defer c.grpcClientMtx.Unlock() if c.conn != nil { diff --git a/internal/util/grpcclient/client_test.go b/internal/util/grpcclient/client_test.go index 246b876ab0..8e45a96eec 100644 --- a/internal/util/grpcclient/client_test.go +++ b/internal/util/grpcclient/client_test.go @@ -40,20 +40,20 @@ import ( ) func TestClientBase_SetRole(t *testing.T) { - base := ClientBase{} + base := ClientBase[any]{} expect := "abc" base.SetRole("abc") assert.Equal(t, expect, base.GetRole()) } func TestClientBase_GetRole(t *testing.T) { - base := ClientBase{} + base := ClientBase[any]{} assert.Equal(t, "", base.GetRole()) } func TestClientBase_connect(t *testing.T) { t.Run("failed to connect", func(t *testing.T) { - base := ClientBase{ + base := ClientBase[any]{ getAddrFunc: func() (string, error) { return "", nil }, @@ -66,7 +66,7 @@ func TestClientBase_connect(t *testing.T) { t.Run("failed to get addr", func(t *testing.T) { errMock := errors.New("mocked") - base := ClientBase{ + base := ClientBase[any]{ getAddrFunc: func() (string, error) { return "", errMock }, @@ -80,13 +80,13 @@ func TestClientBase_connect(t *testing.T) { func TestClientBase_Call(t *testing.T) { // mock client with nothing - base := ClientBase{} + base := ClientBase[any]{} base.grpcClientMtx.Lock() base.grpcClient = struct{}{} base.grpcClientMtx.Unlock() t.Run("Call normal return", func(t *testing.T) { - _, err := base.Call(context.Background(), func(client interface{}) (interface{}, error) { + _, err := base.Call(context.Background(), func(client any) (any, error) { return struct{}{}, nil }) assert.NoError(t, err) @@ -95,7 +95,7 @@ func TestClientBase_Call(t *testing.T) { t.Run("Call with canceled context", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err := base.Call(ctx, func(client interface{}) (interface{}, error) { + _, err := base.Call(ctx, func(client any) (any, error) { return struct{}{}, nil }) assert.Error(t, err) @@ -105,7 +105,7 @@ func TestClientBase_Call(t *testing.T) { t.Run("Call canceled in caller func", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) errMock := errors.New("mocked") - _, err := base.Call(ctx, func(client interface{}) (interface{}, error) { + _, err := base.Call(ctx, func(client any) (any, error) { cancel() return nil, errMock }) @@ -121,7 +121,7 @@ func TestClientBase_Call(t *testing.T) { t.Run("Call canceled in caller func", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) errMock := errors.New("mocked") - _, err := base.Call(ctx, func(client interface{}) (interface{}, error) { + _, err := base.Call(ctx, func(client any) (any, error) { cancel() return nil, errMock }) @@ -138,7 +138,7 @@ func TestClientBase_Call(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() errMock := errors.New("mocked") - _, err := base.Call(ctx, func(client interface{}) (interface{}, error) { + _, err := base.Call(ctx, func(client any) (any, error) { return nil, errMock }) @@ -154,7 +154,7 @@ func TestClientBase_Call(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() errGrpc := status.Error(codes.Unknown, "mocked") - _, err := base.Call(ctx, func(client interface{}) (interface{}, error) { + _, err := base.Call(ctx, func(client any) (any, error) { return nil, errGrpc }) @@ -175,7 +175,7 @@ func TestClientBase_Call(t *testing.T) { t.Run("Call with connect failure", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - _, err := base.Call(ctx, func(client interface{}) (interface{}, error) { + _, err := base.Call(ctx, func(client any) (any, error) { return struct{}{}, nil }) assert.Error(t, err) @@ -185,13 +185,13 @@ func TestClientBase_Call(t *testing.T) { func TestClientBase_Recall(t *testing.T) { // mock client with nothing - base := ClientBase{} + base := ClientBase[any]{} base.grpcClientMtx.Lock() base.grpcClient = struct{}{} base.grpcClientMtx.Unlock() t.Run("Recall normal return", func(t *testing.T) { - _, err := base.ReCall(context.Background(), func(client interface{}) (interface{}, error) { + _, err := base.ReCall(context.Background(), func(client any) (any, error) { return struct{}{}, nil }) assert.NoError(t, err) @@ -200,7 +200,7 @@ func TestClientBase_Recall(t *testing.T) { t.Run("ReCall with canceled context", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - _, err := base.ReCall(ctx, func(client interface{}) (interface{}, error) { + _, err := base.ReCall(ctx, func(client any) (any, error) { return struct{}{}, nil }) assert.Error(t, err) @@ -212,7 +212,7 @@ func TestClientBase_Recall(t *testing.T) { defer cancel() flag := false var mut sync.Mutex - _, err := base.ReCall(ctx, func(client interface{}) (interface{}, error) { + _, err := base.ReCall(ctx, func(client any) (any, error) { mut.Lock() defer mut.Unlock() if flag { @@ -227,7 +227,7 @@ func TestClientBase_Recall(t *testing.T) { t.Run("ReCall canceled in caller func", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) errMock := errors.New("mocked") - _, err := base.ReCall(ctx, func(client interface{}) (interface{}, error) { + _, err := base.ReCall(ctx, func(client any) (any, error) { cancel() return nil, errMock }) @@ -248,7 +248,7 @@ func TestClientBase_Recall(t *testing.T) { t.Run("ReCall with connect failure", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - _, err := base.ReCall(ctx, func(client interface{}) (interface{}, error) { + _, err := base.ReCall(ctx, func(client any) (any, error) { return struct{}{}, nil }) assert.Error(t, err) @@ -304,7 +304,7 @@ func TestClientBase_RetryPolicy(t *testing.T) { }() defer s.Stop() - clientBase := ClientBase{ + clientBase := ClientBase[helloworld.GreeterClient]{ ClientMaxRecvSize: 1 * 1024 * 1024, ClientMaxSendSize: 1 * 1024 * 1024, DialTimeout: 60 * time.Second, @@ -320,17 +320,16 @@ func TestClientBase_RetryPolicy(t *testing.T) { clientBase.SetGetAddrFunc(func() (string, error) { return address, nil }) - clientBase.SetNewGrpcClientFunc(func(cc *grpc.ClientConn) interface{} { + clientBase.SetNewGrpcClientFunc(func(cc *grpc.ClientConn) helloworld.GreeterClient { return helloworld.NewGreeterClient(cc) }) defer clientBase.Close() ctx := context.Background() name := fmt.Sprintf("hello world %d", time.Now().Second()) - res, err := clientBase.Call(ctx, func(client interface{}) (interface{}, error) { - c := client.(helloworld.GreeterClient) + res, err := clientBase.Call(ctx, func(client helloworld.GreeterClient) (any, error) { fmt.Println("client base...") - return c.SayHello(ctx, &helloworld.HelloRequest{Name: name}) + return client.SayHello(ctx, &helloworld.HelloRequest{Name: name}) }) assert.Nil(t, err) assert.Equal(t, res.(*helloworld.HelloReply).Message, strings.ToUpper(name)) diff --git a/internal/util/mock/grpcclient.go b/internal/util/mock/grpcclient.go index 30effd8bf1..21a34d2c26 100644 --- a/internal/util/mock/grpcclient.go +++ b/internal/util/mock/grpcclient.go @@ -26,71 +26,72 @@ import ( "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/generic" "github.com/milvus-io/milvus/internal/util/retry" "github.com/milvus-io/milvus/internal/util/trace" ) -type GRPCClientBase struct { +type GRPCClientBase[T any] struct { getAddrFunc func() (string, error) - newGrpcClient func(cc *grpc.ClientConn) interface{} + newGrpcClient func(cc *grpc.ClientConn) T - grpcClient interface{} + grpcClient T conn *grpc.ClientConn grpcClientMtx sync.RWMutex GetGrpcClientErr error role string } -func (c *GRPCClientBase) SetGetAddrFunc(f func() (string, error)) { +func (c *GRPCClientBase[T]) SetGetAddrFunc(f func() (string, error)) { c.getAddrFunc = f } -func (c *GRPCClientBase) GetRole() string { +func (c *GRPCClientBase[T]) GetRole() string { return c.role } -func (c *GRPCClientBase) SetRole(role string) { +func (c *GRPCClientBase[T]) SetRole(role string) { c.role = role } -func (c *GRPCClientBase) EnableEncryption() { +func (c *GRPCClientBase[T]) EnableEncryption() { } -func (c *GRPCClientBase) SetNewGrpcClientFunc(f func(cc *grpc.ClientConn) interface{}) { +func (c *GRPCClientBase[T]) SetNewGrpcClientFunc(f func(cc *grpc.ClientConn) T) { c.newGrpcClient = f } -func (c *GRPCClientBase) GetGrpcClient(ctx context.Context) (interface{}, error) { +func (c *GRPCClientBase[T]) GetGrpcClient(ctx context.Context) (T, error) { c.grpcClientMtx.RLock() defer c.grpcClientMtx.RUnlock() c.connect(ctx) return c.grpcClient, c.GetGrpcClientErr } -func (c *GRPCClientBase) resetConnection(client interface{}) { +func (c *GRPCClientBase[T]) resetConnection(client T) { c.grpcClientMtx.Lock() defer c.grpcClientMtx.Unlock() - if c.grpcClient == nil { + if generic.IsZero(c.grpcClient) { return } - if client != c.grpcClient { + if !generic.Equal(client, c.grpcClient) { return } if c.conn != nil { _ = c.conn.Close() } c.conn = nil - c.grpcClient = nil + c.grpcClient = generic.Zero[T]() } -func (c *GRPCClientBase) connect(ctx context.Context, retryOptions ...retry.Option) error { +func (c *GRPCClientBase[T]) connect(ctx context.Context, retryOptions ...retry.Option) error { c.grpcClient = c.newGrpcClient(c.conn) return nil } -func (c *GRPCClientBase) callOnce(ctx context.Context, caller func(client interface{}) (interface{}, error)) (interface{}, error) { +func (c *GRPCClientBase[T]) callOnce(ctx context.Context, caller func(client T) (any, error)) (any, error) { client, err := c.GetGrpcClient(ctx) if err != nil { return nil, err @@ -108,7 +109,7 @@ func (c *GRPCClientBase) callOnce(ctx context.Context, caller func(client interf return ret, err2 } -func (c *GRPCClientBase) Call(ctx context.Context, caller func(client interface{}) (interface{}, error)) (interface{}, error) { +func (c *GRPCClientBase[T]) Call(ctx context.Context, caller func(client T) (any, error)) (any, error) { if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() } @@ -116,13 +117,13 @@ func (c *GRPCClientBase) Call(ctx context.Context, caller func(client interface{ ret, err := c.callOnce(ctx, caller) if err != nil { traceErr := fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - log.Error("GRPCClientBase Call grpc first call get error ", zap.Error(traceErr)) + log.Error("GRPCClientBase[T] Call grpc first call get error ", zap.Error(traceErr)) return nil, traceErr } return ret, err } -func (c *GRPCClientBase) ReCall(ctx context.Context, caller func(client interface{}) (interface{}, error)) (interface{}, error) { +func (c *GRPCClientBase[T]) ReCall(ctx context.Context, caller func(client T) (any, error)) (any, error) { // omit ctx check in mock first time to let each function has failed context ret, err := c.callOnce(ctx, caller) if err == nil { @@ -130,7 +131,7 @@ func (c *GRPCClientBase) ReCall(ctx context.Context, caller func(client interfac } traceErr := fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - log.Warn("GRPCClientBase client grpc first call get error ", zap.Error(traceErr)) + log.Warn("GRPCClientBase[T] client grpc first call get error ", zap.Error(traceErr)) if !funcutil.CheckCtxValid(ctx) { return nil, ctx.Err() @@ -139,13 +140,13 @@ func (c *GRPCClientBase) ReCall(ctx context.Context, caller func(client interfac ret, err = c.callOnce(ctx, caller) if err != nil { traceErr = fmt.Errorf("err: %s\n, %s", err.Error(), trace.StackTrace()) - log.Error("GRPCClientBase client grpc second call get error ", zap.Error(traceErr)) + log.Error("GRPCClientBase[T] client grpc second call get error ", zap.Error(traceErr)) return nil, traceErr } return ret, err } -func (c *GRPCClientBase) Close() error { +func (c *GRPCClientBase[T]) Close() error { c.grpcClientMtx.Lock() defer c.grpcClientMtx.Unlock() if c.conn != nil { diff --git a/tests/python_client/testcases/test_utility.py b/tests/python_client/testcases/test_utility.py index e97169e96e..3eaea4bf7f 100644 --- a/tests/python_client/testcases/test_utility.py +++ b/tests/python_client/testcases/test_utility.py @@ -1666,8 +1666,8 @@ class TestUtilityAdvanced(TestcaseBase): dst_node_ids = all_querynodes[1:] # add segment ids which are not exist sealed_segment_ids = [sealed_segment_id - for sealed_segment_id in range(max(segment_distribution[src_node_id]["sealed"]) + 1, - max(segment_distribution[src_node_id]["sealed"]) + 3)] + for sealed_segment_id in range(max(segment_distribution[src_node_id]["sealed"]) + 100, + max(segment_distribution[src_node_id]["sealed"]) + 103)] # load balance self.utility_wrap.load_balance(collection_w.name, src_node_id, dst_node_ids, sealed_segment_ids, check_task=CheckTasks.err_res,