milvus/master/client/client.go
// Copyright 2016 TiKV Project Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pd

import (
	"context"
	"strings"
	"sync"
	"time"

	"github.com/czs007/suvlim/errors"
	"github.com/czs007/suvlim/pkg/pdpb"
	"github.com/opentracing/opentracing-go"
	"github.com/pingcap/log"
	"go.uber.org/zap"
)

// Client is a PD (Placement Driver) client.
// It should not be used after calling Close().
type Client interface {
	// GetClusterID gets the cluster ID from PD.
	GetClusterID(ctx context.Context) uint64
	// GetMemberInfo gets the member info from PD.
	//GetMemberInfo(ctx context.Context) ([]*pdpb.Member, error)
	// GetLeaderAddr returns the current leader's address. It returns "" before
	// the leader has been synced from the server.
	GetLeaderAddr() string
	// GetTS gets a timestamp from PD.
	GetTS(ctx context.Context) (int64, int64, error)
	// GetTSAsync gets a timestamp from PD without blocking the caller.
	GetTSAsync(ctx context.Context) TSFuture
	// Close closes the client and its gRPC connections.
	Close()
}
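
// A minimal usage sketch (illustrative only; the endpoint address and the
// zero-value SecurityOption are assumptions, not something this file defines):
//
//	cli, err := NewClient([]string{"127.0.0.1:2379"}, SecurityOption{})
//	if err != nil {
//		// handle bootstrap failure, e.g. the cluster ID could not be loaded
//	}
//	defer cli.Close()
//	physical, logical, err := cli.GetTS(context.Background())
//	// physical carries the coarse (wall-clock) part of the timestamp and
//	// logical disambiguates timestamps issued within the same physical tick.
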
// tsoRequest is a single pending TSO request; its done channel is signalled
// once the batched RPC carrying it has completed (or failed).
type tsoRequest struct {
	start    time.Time
	ctx      context.Context
	done     chan error
	physical int64
	logical  int64
}

const (
	defaultPDTimeout      = 3 * time.Second
	dialTimeout           = 3 * time.Second
	updateLeaderTimeout   = time.Second // Use a shorter timeout to recover faster from network isolation.
	maxMergeTSORequests   = 10000       // should be higher if the client sends requests in bursts
	maxInitClusterRetries = 100
)

var (
	// errFailInitClusterID is returned when the cluster ID cannot be loaded from any of the supplied PD addresses.
	errFailInitClusterID = errors.New("[pd] failed to get cluster id")
	// errClosing is returned when a request is canceled because the client is closing.
	errClosing = errors.New("[pd] closing")
	// errTSOLength is returned when the number of timestamps in a response does not match the request.
	errTSOLength = errors.New("[pd] tso length in rpc response is incorrect")
)

// client is the Client implementation; it embeds baseClient for connection
// and leader management and adds TSO batching on top.
type client struct {
	*baseClient

	tsoRequests  chan *tsoRequest
	lastPhysical int64
	lastLogical  int64
	tsDeadlineCh chan deadline
}

// NewClient creates a PD client.
func NewClient(pdAddrs []string, security SecurityOption, opts ...ClientOption) (Client, error) {
	return NewClientWithContext(context.Background(), pdAddrs, security, opts...)
}

// NewClientWithContext creates a PD client with context.
func NewClientWithContext(ctx context.Context, pdAddrs []string, security SecurityOption, opts ...ClientOption) (Client, error) {
	log.Info("[pd] create pd client with endpoints", zap.Strings("pd-address", pdAddrs))
	base, err := newBaseClient(ctx, addrsToUrls(pdAddrs), security, opts...)
	if err != nil {
		return nil, err
	}
	c := &client{
		baseClient:   base,
		tsoRequests:  make(chan *tsoRequest, maxMergeTSORequests),
		tsDeadlineCh: make(chan deadline, 1),
	}
	c.wg.Add(2)
	go c.tsLoop()
	go c.tsCancelLoop()
	return c, nil
}

// deadline couples a per-batch timeout timer with the cancel function of the
// TSO stream that is serving the batch.
type deadline struct {
	timer  <-chan time.Time
	done   chan struct{}
	cancel context.CancelFunc
}
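
// tsCancelLoop watches the deadlines registered by tsLoop. If a batch's timer
// fires before its done channel is closed, the stream context is canceled so
// the pending requests fail fast instead of blocking indefinitely.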
func (c *client) tsCancelLoop() {
	defer c.wg.Done()
	ctx, cancel := context.WithCancel(c.ctx)
	defer cancel()
	for {
		select {
		case d := <-c.tsDeadlineCh:
			select {
			case <-d.timer:
				log.Error("tso request is canceled due to timeout")
				d.cancel()
			case <-d.done:
			case <-ctx.Done():
				return
			}
		case <-ctx.Done():
			return
		}
	}
}

// checkStreamTimeout cancels the pending stream creation if it has not
// completed within c.timeout.
func (c *client) checkStreamTimeout(loopCtx context.Context, cancel context.CancelFunc, createdCh chan struct{}) {
	select {
	case <-time.After(c.timeout):
		cancel()
	case <-createdCh:
		return
	case <-loopCtx.Done():
		return
	}
}
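
// tsLoop runs in a dedicated goroutine. It waits for the first queued
// tsoRequest, drains whatever else is already pending, and sends the whole
// batch to the PD leader over a single Tso stream, recreating the stream
// (and re-checking the leader) whenever an error occurs.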
func (c *client) tsLoop() {
	defer c.wg.Done()
	loopCtx, loopCancel := context.WithCancel(c.ctx)
	defer loopCancel()
	defaultSize := maxMergeTSORequests + 1
	requests := make([]*tsoRequest, defaultSize)
	createdCh := make(chan struct{})
	var opts []opentracing.StartSpanOption
	var stream pdpb.PD_TsoClient
	var cancel context.CancelFunc
	for {
		var err error
		if stream == nil {
			var ctx context.Context
			ctx, cancel = context.WithCancel(loopCtx)
			go c.checkStreamTimeout(loopCtx, cancel, createdCh)
			stream, err = c.leaderClient().Tso(ctx)
			if stream != nil {
				createdCh <- struct{}{}
			}
			if err != nil {
				select {
				case <-loopCtx.Done():
					cancel()
					return
				default:
				}
				log.Error("[pd] create tso stream error")
				c.ScheduleCheckLeader()
				cancel()
				c.revokeTSORequest(errors.WithStack(err))
				select {
				case <-time.After(time.Second):
				case <-loopCtx.Done():
					return
				}
				continue
			}
		}
		select {
		case first := <-c.tsoRequests:
			// Drain everything that is already queued so the whole backlog is
			// served by a single Tso RPC.
			pendingPlus1 := len(c.tsoRequests) + 1
			requests[0] = first
			for i := 1; i < pendingPlus1; i++ {
				requests[i] = <-c.tsoRequests
			}
			done := make(chan struct{})
			dl := deadline{
				timer:  time.After(c.timeout),
				done:   done,
				cancel: cancel,
			}
			// Register the batch's deadline so tsCancelLoop can cancel the
			// stream if this RPC takes longer than c.timeout.
			select {
			case c.tsDeadlineCh <- dl:
			case <-loopCtx.Done():
				cancel()
				return
			}
			opts = extractSpanReference(requests[:pendingPlus1], opts[:0])
			err = c.processTSORequests(stream, requests[:pendingPlus1], opts)
			close(done)
		case <-loopCtx.Done():
			cancel()
			return
		}
		if err != nil {
			select {
			case <-loopCtx.Done():
				cancel()
				return
			default:
			}
			log.Error("[pd] getTS error")
			c.ScheduleCheckLeader()
			cancel()
			// Drop the broken stream; it will be recreated on the next iteration.
			stream, cancel = nil, nil
		}
	}
}

// extractSpanReference collects ChildOf references from the request contexts
// so the span of the batched RPC links back to each caller's span.
func extractSpanReference(requests []*tsoRequest, opts []opentracing.StartSpanOption) []opentracing.StartSpanOption {
	for _, req := range requests {
		if span := opentracing.SpanFromContext(req.ctx); span != nil {
			opts = append(opts, opentracing.ChildOf(span.Context()))
		}
	}
	return opts
}

func (c *client) processTSORequests(stream pdpb.PD_TsoClient, requests []*tsoRequest, opts []opentracing.StartSpanOption) error {
	if len(opts) > 0 {
		span := opentracing.StartSpan("pdclient.processTSORequests", opts...)
		defer span.Finish()
	}
	count := len(requests)
	//start := time.Now()
	req := &pdpb.TsoRequest{
		Header: c.requestHeader(),
		Count:  uint32(count),
	}
	if err := stream.Send(req); err != nil {
		err = errors.WithStack(err)
		c.finishTSORequest(requests, 0, 0, err)
		return err
	}
	resp, err := stream.Recv()
	if err != nil {
		err = errors.WithStack(err)
		c.finishTSORequest(requests, 0, 0, err)
		return err
	}
	if resp.GetCount() != uint32(len(requests)) {
		err = errors.WithStack(errTSOLength)
		c.finishTSORequest(requests, 0, 0, err)
		return err
	}
	physical, logical := resp.GetTimestamp().GetPhysical(), resp.GetTimestamp().GetLogical()
	// The server returns the highest ts in the batch, so shift back to the first one.
	logical -= int64(resp.GetCount() - 1)
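	// For example, with three batched requests and a response whose logical
	// part is 10, logical becomes 8 here; finishTSORequest then hands out
	// 8, 9 and 10, so every request in the batch gets a distinct timestamp.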
	if tsLessEqual(physical, logical, c.lastPhysical, c.lastLogical) {
		panic(errors.Errorf("timestamp fallback, newly acquired ts (%d, %d) is less or equal to last one (%d, %d)",
			physical, logical, c.lastPhysical, c.lastLogical))
	}
	c.lastPhysical = physical
	c.lastLogical = logical + int64(len(requests)) - 1
	c.finishTSORequest(requests, physical, logical, nil)
	return nil
}

// tsLessEqual reports whether (physical, logical) <= (thatPhysical, thatLogical).
func tsLessEqual(physical, logical, thatPhysical, thatLogical int64) bool {
	if physical == thatPhysical {
		return logical <= thatLogical
	}
	return physical < thatPhysical
}

func (c *client) finishTSORequest(requests []*tsoRequest, physical, firstLogical int64, err error) {
	for i := 0; i < len(requests); i++ {
		if span := opentracing.SpanFromContext(requests[i].ctx); span != nil {
			span.Finish()
		}
		requests[i].physical, requests[i].logical = physical, firstLogical+int64(i)
		requests[i].done <- err
	}
}

// revokeTSORequest fails every request that is still queued with the given error.
func (c *client) revokeTSORequest(err error) {
	n := len(c.tsoRequests)
	for i := 0; i < n; i++ {
		req := <-c.tsoRequests
		req.done <- err
	}
}

// Close cancels the background goroutines, fails any queued requests and
// closes all gRPC connections.
func (c *client) Close() {
	c.cancel()
	c.wg.Wait()
	c.revokeTSORequest(errors.WithStack(errClosing))
	c.connMu.Lock()
	defer c.connMu.Unlock()
	for _, cc := range c.connMu.clientConns {
		if err := cc.Close(); err != nil {
			log.Error("[pd] failed to close gRPC clientConn")
		}
	}
}

// leaderClient returns a PD gRPC client bound to the current leader's connection.
func (c *client) leaderClient() pdpb.PDClient {
	c.connMu.RLock()
	defer c.connMu.RUnlock()
	return pdpb.NewPDClient(c.connMu.clientConns[c.connMu.leader])
}

// tsoReqPool recycles tsoRequest objects between GetTSAsync and Wait to avoid
// allocating one per call.
var tsoReqPool = sync.Pool{
	New: func() interface{} {
		return &tsoRequest{
			done:     make(chan error, 1),
			physical: 0,
			logical:  0,
		}
	},
}
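
// GetTSAsync enqueues a TSO request and returns a future immediately; tsLoop
// later batches the request into a Tso RPC and finishTSORequest resolves the
// future with the assigned timestamp or an error.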
func (c *client) GetTSAsync(ctx context.Context) TSFuture {
	if span := opentracing.SpanFromContext(ctx); span != nil {
		span = opentracing.StartSpan("GetTSAsync", opentracing.ChildOf(span.Context()))
		ctx = opentracing.ContextWithSpan(ctx, span)
	}
	req := tsoReqPool.Get().(*tsoRequest)
	req.ctx = ctx
	req.start = time.Now()
	c.tsoRequests <- req
	return req
}

// TSFuture is a future which promises to return a TSO.
type TSFuture interface {
	// Wait gets the physical and logical time; it blocks the caller until the
	// result is available.
	Wait() (int64, int64, error)
}
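
// A small sketch of the asynchronous path (caller-side code, assumed for
// illustration; cli and ctx are not defined in this file): several requests
// are queued first so tsLoop can batch them, then each future is awaited.
//
//	futures := make([]TSFuture, 0, 4)
//	for i := 0; i < 4; i++ {
//		futures = append(futures, cli.GetTSAsync(ctx))
//	}
//	for _, f := range futures {
//		physical, logical, err := f.Wait()
//		_, _, _ = physical, logical, err // use the timestamp here
//	}
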
func (req *tsoRequest) Wait() (physical int64, logical int64, err error) {
	// If the observed tso command duration is very high, the reason could be
	// that it takes too long for Wait() to be called after GetTSAsync.
	select {
	case err = <-req.done:
		err = errors.WithStack(err)
		defer tsoReqPool.Put(req)
		if err != nil {
			return 0, 0, err
		}
		physical, logical = req.physical, req.logical
		return
	case <-req.ctx.Done():
		return 0, 0, errors.WithStack(req.ctx.Err())
	}
}

// GetTS is the synchronous variant of GetTSAsync.
func (c *client) GetTS(ctx context.Context) (physical int64, logical int64, err error) {
	resp := c.GetTSAsync(ctx)
	return resp.Wait()
}

func (c *client) requestHeader() *pdpb.RequestHeader {
	return &pdpb.RequestHeader{
		ClusterId: c.clusterID,
	}
}
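
// addrsToUrls normalizes the configured PD addresses into URLs. For example,
// "127.0.0.1:2379" becomes "http://127.0.0.1:2379", while an address that
// already contains a scheme is kept unchanged.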
func addrsToUrls(addrs []string) []string {
	// Add the default scheme "http://" to addrs that do not specify one.
	urls := make([]string, 0, len(addrs))
	for _, addr := range addrs {
		if strings.Contains(addr, "://") {
			urls = append(urls, addr)
		} else {
			urls = append(urls, "http://"+addr)
		}
	}
	return urls
}