diff --git a/internal/distributed/datanode/datanode_test.go b/internal/distributed/datanode/datanode_test.go index fb27fec445..e403056333 100644 --- a/internal/distributed/datanode/datanode_test.go +++ b/internal/distributed/datanode/datanode_test.go @@ -18,6 +18,7 @@ import ( "net" "strconv" "testing" + "time" "github.com/milvus-io/milvus/internal/msgstream" "github.com/milvus-io/milvus/internal/proto/commonpb" @@ -124,7 +125,7 @@ func TestRun(t *testing.T) { dnServer.newMasterServiceClient = func(s string) (types.MasterService, error) { return &mockMaster{}, nil } - dnServer.newDataServiceClient = func(s string) types.DataService { + dnServer.newDataServiceClient = func(s, etcdAddress string, timeout time.Duration) types.DataService { return &mockDataService{} } diff --git a/internal/distributed/datanode/service.go b/internal/distributed/datanode/service.go index 30097880ce..1c617268f5 100644 --- a/internal/distributed/datanode/service.go +++ b/internal/distributed/datanode/service.go @@ -56,7 +56,7 @@ type Server struct { dataService types.DataService newMasterServiceClient func(string) (types.MasterService, error) - newDataServiceClient func(string) types.DataService + newDataServiceClient func(string, string, time.Duration) types.DataService closer io.Closer } @@ -72,8 +72,8 @@ func NewServer(ctx context.Context, factory msgstream.Factory) (*Server, error) newMasterServiceClient: func(s string) (types.MasterService, error) { return msc.NewClient(s, []string{dn.Params.EtcdAddress}, 20*time.Second) }, - newDataServiceClient: func(s string) types.DataService { - return dsc.NewClient(Params.DataServiceAddress) + newDataServiceClient: func(s, etcdAddress string, timeout time.Duration) types.DataService { + return dsc.NewClient(Params.DataServiceAddress, []string{etcdAddress}, timeout) }, } @@ -205,7 +205,7 @@ func (s *Server) init() error { if s.newDataServiceClient != nil { log.Debug("Data service address", zap.String("address", Params.DataServiceAddress)) log.Debug("DataNode Init data service client ...") - dataServiceClient := s.newDataServiceClient(Params.DataServiceAddress) + dataServiceClient := s.newDataServiceClient(Params.DataServiceAddress, dn.Params.EtcdAddress, 10) if err = dataServiceClient.Init(); err != nil { panic(err) } diff --git a/internal/distributed/dataservice/client/client.go b/internal/distributed/dataservice/client/client.go index 3a47c779a3..45cc3a8538 100644 --- a/internal/distributed/dataservice/client/client.go +++ b/internal/distributed/dataservice/client/client.go @@ -13,11 +13,14 @@ package grpcdataserviceclient import ( "context" + "fmt" "time" "github.com/milvus-io/milvus/internal/log" "github.com/milvus-io/milvus/internal/proto/milvuspb" "github.com/milvus-io/milvus/internal/util/retry" + "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/internal/util/typeutil" otgrpc "github.com/opentracing-contrib/go-grpc" "github.com/opentracing/opentracing-go" "go.uber.org/zap" @@ -33,19 +36,83 @@ type Client struct { conn *grpc.ClientConn ctx context.Context addr string + + sess *sessionutil.Session + timeout time.Duration + recallTry int + reconnTry int } -func NewClient(addr string) *Client { +func getDataServiceAddress(sess *sessionutil.Session) (string, error) { + key := typeutil.DataServiceRole + msess, _, err := sess.GetSessions(key) + if err != nil { + return "", err + } + ms, ok := msess[key] + if !ok { + return "", fmt.Errorf("number of master service is incorrect, %d", len(msess)) + } + return ms.Address, nil +} + +func NewClient(address string, etcdAddr []string, timeout time.Duration) *Client { + sess := sessionutil.NewSession(context.Background(), etcdAddr) return &Client{ - addr: addr, - ctx: context.Background(), + addr: address, + ctx: context.Background(), + sess: sess, + timeout: timeout, + recallTry: 3, + reconnTry: 10, } } func (c *Client) Init() error { tracer := opentracing.GlobalTracer() + if c.addr != "" { + connectGrpcFunc := func() error { + log.Debug("dataservice connect ", zap.String("address", c.addr)) + conn, err := grpc.DialContext(c.ctx, c.addr, grpc.WithInsecure(), grpc.WithBlock(), + grpc.WithUnaryInterceptor( + otgrpc.OpenTracingClientInterceptor(tracer)), + grpc.WithStreamInterceptor( + otgrpc.OpenTracingStreamClientInterceptor(tracer))) + if err != nil { + return err + } + c.conn = conn + return nil + } + + err := retry.Retry(100000, time.Millisecond*200, connectGrpcFunc) + if err != nil { + return err + } + } else { + return c.reconnect() + } + c.grpcClient = datapb.NewDataServiceClient(c.conn) + + return nil +} + +func (c *Client) reconnect() error { + tracer := opentracing.GlobalTracer() + var err error + getDataServiceAddressFn := func() error { + c.addr, err = getDataServiceAddress(c.sess) + if err != nil { + return err + } + return nil + } + err = retry.Retry(c.reconnTry, 3*time.Second, getDataServiceAddressFn) + if err != nil { + return err + } connectGrpcFunc := func() error { - log.Debug("dataservice connect ", zap.String("address", c.addr)) + log.Debug("DataService connect ", zap.String("address", c.addr)) conn, err := grpc.DialContext(c.ctx, c.addr, grpc.WithInsecure(), grpc.WithBlock(), grpc.WithUnaryInterceptor( otgrpc.OpenTracingClientInterceptor(tracer)), @@ -58,15 +125,31 @@ func (c *Client) Init() error { return nil } - err := retry.Retry(100000, time.Millisecond*200, connectGrpcFunc) + err = retry.Retry(c.reconnTry, 500*time.Millisecond, connectGrpcFunc) if err != nil { return err } c.grpcClient = datapb.NewDataServiceClient(c.conn) - return nil } +func (c *Client) recall(caller func() (interface{}, error)) (interface{}, error) { + ret, err := caller() + if err == nil { + return ret, nil + } + for i := 0; i < c.recallTry; i++ { + err = c.reconnect() + if err == nil { + ret, err = caller() + if err == nil { + return ret, nil + } + } + } + return ret, err +} + func (c *Client) Start() error { return nil } @@ -81,57 +164,99 @@ func (c *Client) Register() error { } func (c *Client) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { - return c.grpcClient.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetComponentStates(ctx, &internalpb.GetComponentStatesRequest{}) + }) + return ret.(*internalpb.ComponentStates), err } func (c *Client) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - return c.grpcClient.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetTimeTickChannel(ctx, &internalpb.GetTimeTickChannelRequest{}) + }) + return ret.(*milvuspb.StringResponse), err } func (c *Client) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - return c.grpcClient.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetStatisticsChannel(ctx, &internalpb.GetStatisticsChannelRequest{}) + }) + return ret.(*milvuspb.StringResponse), err } func (c *Client) RegisterNode(ctx context.Context, req *datapb.RegisterNodeRequest) (*datapb.RegisterNodeResponse, error) { - return c.grpcClient.RegisterNode(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.RegisterNode(ctx, req) + }) + return ret.(*datapb.RegisterNodeResponse), err } func (c *Client) Flush(ctx context.Context, req *datapb.FlushRequest) (*commonpb.Status, error) { - return c.grpcClient.Flush(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.Flush(ctx, req) + }) + return ret.(*commonpb.Status), err } func (c *Client) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentIDRequest) (*datapb.AssignSegmentIDResponse, error) { - return c.grpcClient.AssignSegmentID(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.AssignSegmentID(ctx, req) + }) + return ret.(*datapb.AssignSegmentIDResponse), err } func (c *Client) ShowSegments(ctx context.Context, req *datapb.ShowSegmentsRequest) (*datapb.ShowSegmentsResponse, error) { - return c.grpcClient.ShowSegments(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.ShowSegments(ctx, req) + }) + return ret.(*datapb.ShowSegmentsResponse), err } func (c *Client) GetSegmentStates(ctx context.Context, req *datapb.GetSegmentStatesRequest) (*datapb.GetSegmentStatesResponse, error) { - return c.grpcClient.GetSegmentStates(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetSegmentStates(ctx, req) + }) + return ret.(*datapb.GetSegmentStatesResponse), err } func (c *Client) GetInsertBinlogPaths(ctx context.Context, req *datapb.GetInsertBinlogPathsRequest) (*datapb.GetInsertBinlogPathsResponse, error) { - return c.grpcClient.GetInsertBinlogPaths(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetInsertBinlogPaths(ctx, req) + }) + return ret.(*datapb.GetInsertBinlogPathsResponse), err } func (c *Client) GetInsertChannels(ctx context.Context, req *datapb.GetInsertChannelsRequest) (*internalpb.StringList, error) { - return c.grpcClient.GetInsertChannels(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetInsertChannels(ctx, req) + }) + return ret.(*internalpb.StringList), err } func (c *Client) GetCollectionStatistics(ctx context.Context, req *datapb.GetCollectionStatisticsRequest) (*datapb.GetCollectionStatisticsResponse, error) { - return c.grpcClient.GetCollectionStatistics(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetCollectionStatistics(ctx, req) + }) + return ret.(*datapb.GetCollectionStatisticsResponse), err } func (c *Client) GetPartitionStatistics(ctx context.Context, req *datapb.GetPartitionStatisticsRequest) (*datapb.GetPartitionStatisticsResponse, error) { - return c.grpcClient.GetPartitionStatistics(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetPartitionStatistics(ctx, req) + }) + return ret.(*datapb.GetPartitionStatisticsResponse), err } func (c *Client) GetSegmentInfoChannel(ctx context.Context) (*milvuspb.StringResponse, error) { - return c.grpcClient.GetSegmentInfoChannel(ctx, &datapb.GetSegmentInfoChannelRequest{}) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetSegmentInfoChannel(ctx, &datapb.GetSegmentInfoChannelRequest{}) + }) + return ret.(*milvuspb.StringResponse), err } func (c *Client) GetSegmentInfo(ctx context.Context, req *datapb.GetSegmentInfoRequest) (*datapb.GetSegmentInfoResponse, error) { - return c.grpcClient.GetSegmentInfo(ctx, req) + ret, err := c.recall(func() (interface{}, error) { + return c.grpcClient.GetSegmentInfo(ctx, req) + }) + return ret.(*datapb.GetSegmentInfoResponse), err } diff --git a/internal/distributed/masterservice/masterservice_test.go b/internal/distributed/masterservice/masterservice_test.go index 7e6d409e90..4858d96e02 100644 --- a/internal/distributed/masterservice/masterservice_test.go +++ b/internal/distributed/masterservice/masterservice_test.go @@ -925,7 +925,7 @@ func TestRun(t *testing.T) { svr.newProxyServiceClient = func(s string) types.ProxyService { return &mockProxy{} } - svr.newDataServiceClient = func(s string) types.DataService { + svr.newDataServiceClient = func(s, address string, timeout time.Duration) types.DataService { return &mockDataService{} } svr.newIndexServiceClient = func(s string) types.IndexService { diff --git a/internal/distributed/masterservice/server.go b/internal/distributed/masterservice/server.go index 22accca6d3..0ddde581f7 100644 --- a/internal/distributed/masterservice/server.go +++ b/internal/distributed/masterservice/server.go @@ -59,7 +59,7 @@ type Server struct { queryService types.QueryService newProxyServiceClient func(string) types.ProxyService - newDataServiceClient func(string) types.DataService + newDataServiceClient func(string, string, time.Duration) types.DataService newIndexServiceClient func(string) types.IndexService newQueryServiceClient func(string) (types.QueryService, error) @@ -98,8 +98,8 @@ func (s *Server) setClient() { } return psClient } - s.newDataServiceClient = func(s string) types.DataService { - dsClient := dsc.NewClient(s) + s.newDataServiceClient = func(s, etcdAddress string, timeout time.Duration) types.DataService { + dsClient := dsc.NewClient(s, []string{etcdAddress}, timeout) if err := dsClient.Init(); err != nil { panic(err) } @@ -183,7 +183,7 @@ func (s *Server) init() error { } if s.newDataServiceClient != nil { log.Debug("data service", zap.String("address", Params.DataServiceAddress)) - dataService := s.newDataServiceClient(Params.DataServiceAddress) + dataService := s.newDataServiceClient(Params.DataServiceAddress, cms.Params.EtcdAddress, 10) if err := s.masterService.SetDataService(ctx, dataService); err != nil { panic(err) } diff --git a/internal/distributed/proxynode/service.go b/internal/distributed/proxynode/service.go index d028a0e16a..359fc0a850 100644 --- a/internal/distributed/proxynode/service.go +++ b/internal/distributed/proxynode/service.go @@ -207,7 +207,7 @@ func (s *Server) init() error { dataServiceAddr := Params.DataServiceAddress log.Debug("proxynode", zap.String("data service address", dataServiceAddr)) - s.dataServiceClient = grpcdataserviceclient.NewClient(dataServiceAddr) + s.dataServiceClient = grpcdataserviceclient.NewClient(dataServiceAddr, []string{proxynode.Params.EtcdAddress}, 10) err = s.dataServiceClient.Init() if err != nil { return err diff --git a/internal/distributed/querynode/service.go b/internal/distributed/querynode/service.go index ad8ec40ac2..845c552fd2 100644 --- a/internal/distributed/querynode/service.go +++ b/internal/distributed/querynode/service.go @@ -181,7 +181,7 @@ func (s *Server) init() error { log.Debug("Data service", zap.String("address", Params.DataServiceAddress)) log.Debug("QueryNode Init data service client ...") - dataService := dsc.NewClient(Params.DataServiceAddress) + dataService := dsc.NewClient(Params.DataServiceAddress, []string{qn.Params.EtcdAddress}, 10) if err = dataService.Init(); err != nil { panic(err) } diff --git a/internal/distributed/queryservice/service.go b/internal/distributed/queryservice/service.go index 6c8788b8f3..d61e6f1241 100644 --- a/internal/distributed/queryservice/service.go +++ b/internal/distributed/queryservice/service.go @@ -138,7 +138,7 @@ func (s *Server) init() error { log.Debug("DataService", zap.String("Address", Params.DataServiceAddress)) log.Debug("QueryService Init data service client ...") - dataService := dsc.NewClient(Params.DataServiceAddress) + dataService := dsc.NewClient(Params.DataServiceAddress, []string{qs.Params.EtcdAddress}, 10) if err = dataService.Init(); err != nil { panic(err) }