From 9353dfb8439bb647f09e909be1722410dc94d26b Mon Sep 17 00:00:00 2001 From: congqixia Date: Thu, 18 Aug 2022 16:12:50 +0800 Subject: [PATCH] Simplify Dml-DeltaChannel mapping logic (#18708) Signed-off-by: Congqi Xia Signed-off-by: Congqi Xia --- internal/rootcoord/dml_channels.go | 182 +++++++++++++++++++----- internal/rootcoord/dml_channels_test.go | 112 ++++++++++++--- internal/rootcoord/root_coord_test.go | 19 +-- internal/rootcoord/task.go | 40 ++---- internal/rootcoord/timeticksync.go | 45 +----- 5 files changed, 269 insertions(+), 129 deletions(-) diff --git a/internal/rootcoord/dml_channels.go b/internal/rootcoord/dml_channels.go index d98fd69737..f22c673803 100644 --- a/internal/rootcoord/dml_channels.go +++ b/internal/rootcoord/dml_channels.go @@ -17,13 +17,13 @@ package rootcoord import ( + "container/heap" "context" "fmt" "sync" "github.com/milvus-io/milvus/internal/metrics" - "go.uber.org/atomic" "go.uber.org/zap" "github.com/milvus-io/milvus/internal/log" @@ -31,9 +31,102 @@ import ( ) type dmlMsgStream struct { - ms msgstream.MsgStream - mutex sync.RWMutex - refcnt int64 + ms msgstream.MsgStream + mutex sync.RWMutex + + refcnt int64 // current in use count + used int64 // total used counter in current run, not stored in meta so meant to be inaccurate + idx int64 // idx for name + pos int // position in the heap slice +} + +// RefCnt returns refcnt with mutex protection. +func (dms *dmlMsgStream) RefCnt() int64 { + dms.mutex.RLock() + defer dms.mutex.RUnlock() + return dms.refcnt +} + +// Used returns used with mutex protection. +func (dms *dmlMsgStream) Used() int64 { + dms.mutex.RLock() + defer dms.mutex.RUnlock() + return dms.used +} + +// IncRefcnt increases refcnt. +func (dms *dmlMsgStream) IncRefcnt() { + dms.mutex.Lock() + defer dms.mutex.Unlock() + dms.refcnt++ +} + +// BookUsage increases used, acting like reservation usage. 
+func (dms *dmlMsgStream) BookUsage() { + dms.mutex.Lock() + defer dms.mutex.Unlock() + dms.used++ +} + +// DecRefCnt decreases refcnt only. +func (dms *dmlMsgStream) DecRefCnt() { + dms.mutex.Lock() + defer dms.mutex.Unlock() + if dms.refcnt > 0 { + dms.refcnt-- + } else { + log.Warn("Try to remove channel with no ref count", zap.Int64("idx", dms.idx)) + } +} + +// channelsHeap implements heap.Interface so that it behaves like a priority queue. +type channelsHeap []*dmlMsgStream + +// Len is the number of elements in the collection. +func (h channelsHeap) Len() int { + return len(h) +} + +// Less reports whether the element with index i +// must sort before the element with index j. +func (h channelsHeap) Less(i int, j int) bool { + ei, ej := h[i], h[j] + // prefer the channel with the smaller refcnt + rci, rcj := ei.RefCnt(), ej.RefCnt() + if rci != rcj { + return rci < rcj + } + + // then prefer the channel with the smaller used count + ui, uj := ei.Used(), ej.Used() + if ui != uj { + return ui < uj + } + + // all counters equal, prefer the channel with the smaller idx + return ei.idx < ej.idx +} + +// Swap swaps the elements with indexes i and j. +func (h channelsHeap) Swap(i int, j int) { + h[i], h[j] = h[j], h[i] + h[i].pos, h[j].pos = i, j +} + +// Push adds a new element to the heap. +func (h *channelsHeap) Push(x interface{}) { + item := x.(*dmlMsgStream) + *h = append(*h, item) +} + +// Pop implements heap.Interface; it removes and returns the last element of the slice. 
+func (h *channelsHeap) Pop() interface{} { + old := *h + n := len(old) + item := old[n-1] + old[n-1] = nil + *h = old[0 : n-1] + return item } type dmlChannels struct { @@ -41,18 +134,21 @@ type dmlChannels struct { factory msgstream.Factory namePrefix string capacity int64 - idx *atomic.Int64 - pool sync.Map + // pool maintains channelName => dmlMsgStream mapping, stable + pool sync.Map + // mut protects channelsHeap only + mut sync.Mutex + // channelsHeap is the heap to pop next dms for use + channelsHeap channelsHeap } func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePrefix string, chanNum int64) *dmlChannels { d := &dmlChannels{ - ctx: ctx, - factory: factory, - namePrefix: chanNamePrefix, - capacity: chanNum, - idx: atomic.NewInt64(0), - pool: sync.Map{}, + ctx: ctx, + factory: factory, + namePrefix: chanNamePrefix, + capacity: chanNum, + channelsHeap: make([]*dmlMsgStream, 0, chanNum), } for i := int64(0); i < chanNum; i++ { @@ -63,12 +159,19 @@ func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePref panic("Failed to add msgstream") } ms.AsProducer([]string{name}) - d.pool.Store(name, &dmlMsgStream{ + + dms := &dmlMsgStream{ ms: ms, - mutex: sync.RWMutex{}, refcnt: 0, - }) + used: 0, + idx: i, + pos: int(i), + } + d.pool.Store(name, dms) + d.channelsHeap = append(d.channelsHeap, dms) } + + heap.Init(&d.channelsHeap) log.Debug("init dml channels", zap.Int64("num", chanNum)) metrics.RootCoordNumOfDMLChannel.Add(float64(chanNum)) metrics.RootCoordNumOfMsgStream.Add(float64(chanNum)) @@ -76,21 +179,38 @@ func newDmlChannels(ctx context.Context, factory msgstream.Factory, chanNamePref return d } -func (d *dmlChannels) getChannelName() string { - cnt := d.idx.Inc() - return genChannelName(d.namePrefix, (cnt-1)%d.capacity) +func (d *dmlChannels) getChannelNames(count int) []string { + if count > len(d.channelsHeap) { + return nil + } + d.mut.Lock() + defer d.mut.Unlock() + // get next count items from heap + items 
:= make([]*dmlMsgStream, 0, count) + result := make([]string, 0, count) + for i := 0; i < count; i++ { + item := heap.Pop(&d.channelsHeap).(*dmlMsgStream) + item.BookUsage() + items = append(items, item) + result = append(result, genChannelName(d.namePrefix, item.idx)) + } + + for _, item := range items { + heap.Push(&d.channelsHeap, item) + } + + return result } func (d *dmlChannels) listChannels() []string { var chanNames []string + d.pool.Range( func(k, v interface{}) bool { dms := v.(*dmlMsgStream) - dms.mutex.RLock() - if dms.refcnt > 0 { - chanNames = append(chanNames, k.(string)) + if dms.RefCnt() > 0 { + chanNames = append(chanNames, genChannelName(d.namePrefix, dms.idx)) } - dms.mutex.RUnlock() return true }) return chanNames @@ -161,9 +281,10 @@ func (d *dmlChannels) addChannels(names ...string) { } dms := v.(*dmlMsgStream) - dms.mutex.Lock() - dms.refcnt++ - dms.mutex.Unlock() + d.mut.Lock() + dms.IncRefcnt() + heap.Fix(&d.channelsHeap, dms.pos) + d.mut.Unlock() } } @@ -176,13 +297,10 @@ func (d *dmlChannels) removeChannels(names ...string) { } dms := v.(*dmlMsgStream) - dms.mutex.Lock() - if dms.refcnt > 0 { - dms.refcnt-- - } else { - log.Warn("Try to remove channel with no ref count", zap.String("channel name", name)) - } - dms.mutex.Unlock() + d.mut.Lock() + dms.DecRefCnt() + heap.Fix(&d.channelsHeap, dms.pos) + d.mut.Unlock() } } diff --git a/internal/rootcoord/dml_channels_test.go b/internal/rootcoord/dml_channels_test.go index 0086ae4643..679385c135 100644 --- a/internal/rootcoord/dml_channels_test.go +++ b/internal/rootcoord/dml_channels_test.go @@ -17,8 +17,10 @@ package rootcoord import ( + "container/heap" "context" "errors" + "math/rand" "sync" "testing" @@ -31,6 +33,93 @@ import ( "github.com/stretchr/testify/require" ) +func TestDmlMsgStream(t *testing.T) { + t.Run("RefCnt", func(t *testing.T) { + + dms := &dmlMsgStream{refcnt: 0} + assert.Equal(t, int64(0), dms.RefCnt()) + assert.Equal(t, int64(0), dms.Used()) + + dms.IncRefcnt() + 
assert.Equal(t, int64(1), dms.RefCnt()) + dms.BookUsage() + assert.Equal(t, int64(1), dms.Used()) + + dms.DecRefCnt() + assert.Equal(t, int64(0), dms.RefCnt()) + assert.Equal(t, int64(1), dms.Used()) + + dms.DecRefCnt() + assert.Equal(t, int64(0), dms.RefCnt()) + assert.Equal(t, int64(1), dms.Used()) + }) +} + +func TestChannelsHeap(t *testing.T) { + chanNum := 16 + var h channelsHeap + h = make([]*dmlMsgStream, 0, chanNum) + + for i := int64(0); i < int64(chanNum); i++ { + dms := &dmlMsgStream{ + refcnt: 0, + used: 0, + idx: i, + pos: int(i), + } + h = append(h, dms) + } + + check := func(h channelsHeap) bool { + for i := 0; i < chanNum; i++ { + if h[i].pos != i { + return false + } + if i*2+1 < chanNum { + if !h.Less(i, i*2+1) { + t.Log("left", i) + return false + } + } + if i*2+2 < chanNum { + if !h.Less(i, i*2+2) { + t.Log("right", i) + return false + } + } + } + return true + } + + heap.Init(&h) + + assert.True(t, check(h)) + + // add usage for all + for i := 0; i < chanNum; i++ { + h[0].BookUsage() + h[0].IncRefcnt() + heap.Fix(&h, 0) + } + + assert.True(t, check(h)) + for i := 0; i < chanNum; i++ { + assert.EqualValues(t, 1, h[i].RefCnt()) + assert.EqualValues(t, 1, h[i].Used()) + } + + randIdx := rand.Intn(chanNum) + + target := h[randIdx] + h[randIdx].DecRefCnt() + heap.Fix(&h, randIdx) + assert.EqualValues(t, 0, target.pos) + + next := heap.Pop(&h).(*dmlMsgStream) + + assert.Equal(t, target, next) +} + func TestDmlChannels(t *testing.T) { const ( dmlChanPrefix = "rootcoord-dml" @@ -52,27 +141,18 @@ func TestDmlChannels(t *testing.T) { assert.Panics(t, func() { dml.broadcastMark([]string{randStr}, nil) }) assert.Panics(t, func() { dml.removeChannels(randStr) }) - // dml_xxx_0 => {chanName0, chanName2} - // dml_xxx_1 => {chanName1} - chanName0 := dml.getChannelName() - dml.addChannels(chanName0) - assert.Equal(t, 1, dml.getChannelNum()) - - chanName1 := dml.getChannelName() - dml.addChannels(chanName1) + chans0 := dml.getChannelNames(2) + 
dml.addChannels(chans0...) assert.Equal(t, 2, dml.getChannelNum()) - chanName2 := dml.getChannelName() - dml.addChannels(chanName2) + chans1 := dml.getChannelNames(1) + dml.addChannels(chans1...) assert.Equal(t, 2, dml.getChannelNum()) - dml.removeChannels(chanName0) + dml.removeChannels(chans1...) assert.Equal(t, 2, dml.getChannelNum()) - dml.removeChannels(chanName1) - assert.Equal(t, 1, dml.getChannelNum()) - - dml.removeChannels(chanName0) + dml.removeChannels(chans0...) assert.Equal(t, 0, dml.getChannelNum()) } @@ -90,7 +170,7 @@ func TestDmChannelsFailure(t *testing.T) { defer wg.Done() mockFactory := &FailMessageStreamFactory{errBroadcast: true} dml := newDmlChannels(context.TODO(), mockFactory, "test-newdmlchannel-root", 1) - chanName0 := dml.getChannelName() + chanName0 := dml.getChannelNames(1)[0] dml.addChannels(chanName0) require.Equal(t, 1, dml.getChannelNum()) diff --git a/internal/rootcoord/root_coord_test.go b/internal/rootcoord/root_coord_test.go index 5dec5b02f0..9e0e3cc4a4 100644 --- a/internal/rootcoord/root_coord_test.go +++ b/internal/rootcoord/root_coord_test.go @@ -528,10 +528,9 @@ func createCollectionInMeta(dbName, collName string, core *Core, shardsNum int32 } vchanNames := make([]string, t.ShardsNum) - chanNames := make([]string, t.ShardsNum) + chanNames := core.chanTimeTick.getDmlChannelNames(int(t.ShardsNum)) for i := int32(0); i < t.ShardsNum; i++ { - vchanNames[i] = fmt.Sprintf("%s_%dv%d", core.chanTimeTick.getDmlChannelName(), collID, i) - chanNames[i] = funcutil.ToPhysicalChannel(vchanNames[i]) + vchanNames[i] = fmt.Sprintf("%s_%dv%d", chanNames[i], collID, i) } collInfo := etcdpb.CollectionInfo{ @@ -2335,15 +2334,11 @@ func TestRootCoord_Base(t *testing.T) { assert.NoError(t, err) time.Sleep(100 * time.Millisecond) - cn0 := core.chanTimeTick.getDmlChannelName() - cn1 := core.chanTimeTick.getDmlChannelName() - cn2 := core.chanTimeTick.getDmlChannelName() - core.chanTimeTick.addDmlChannels(cn0, cn1, cn2) - - dn0 := 
core.chanTimeTick.getDeltaChannelName() - dn1 := core.chanTimeTick.getDeltaChannelName() - dn2 := core.chanTimeTick.getDeltaChannelName() - core.chanTimeTick.addDeltaChannels(dn0, dn1, dn2) + cns := core.chanTimeTick.getDmlChannelNames(3) + cn0 := cns[0] + cn1 := cns[1] + cn2 := cns[2] + core.chanTimeTick.addDmlChannels(cns...) // wait for local channel reported for { diff --git a/internal/rootcoord/task.go b/internal/rootcoord/task.go index 0e60cb9b73..70924c873a 100644 --- a/internal/rootcoord/task.go +++ b/internal/rootcoord/task.go @@ -156,26 +156,18 @@ func (t *CreateCollectionReqTask) Execute(ctx context.Context) error { zap.Int64("default partition id", partID)) vchanNames := make([]string, t.Req.ShardsNum) - chanNames := make([]string, t.Req.ShardsNum) deltaChanNames := make([]string, t.Req.ShardsNum) - for i := int32(0); i < t.Req.ShardsNum; i++ { - vchanNames[i] = fmt.Sprintf("%s_%dv%d", t.core.chanTimeTick.getDmlChannelName(), collID, i) - chanNames[i] = funcutil.ToPhysicalChannel(vchanNames[i]) - deltaChanNames[i] = t.core.chanTimeTick.getDeltaChannelName() - deltaChanName, err1 := funcutil.ConvertChannelName(chanNames[i], Params.CommonCfg.RootCoordDml, Params.CommonCfg.RootCoordDelta) - if err1 != nil || deltaChanName != deltaChanNames[i] { - err1Msg := "" - if err1 != nil { - err1Msg = err1.Error() - } - log.Debug("dmlChanName deltaChanName mismatch detail", zap.Int32("i", i), - zap.String("vchanName", vchanNames[i]), - zap.String("phsicalChanName", chanNames[i]), - zap.String("deltaChanName", deltaChanNames[i]), - zap.String("converted_deltaChanName", deltaChanName), - zap.String("err", err1Msg)) - return fmt.Errorf("dmlChanName %s and deltaChanName %s mis-match", chanNames[i], deltaChanNames[i]) + //physical channel names + chanNames := t.core.chanTimeTick.getDmlChannelNames(int(t.Req.ShardsNum)) + for i := int32(0); i < t.Req.ShardsNum; i++ { + vchanNames[i] = fmt.Sprintf("%s_%dv%d", chanNames[i], collID, i) + deltaChanNames[i], err = 
funcutil.ConvertChannelName(chanNames[i], Params.CommonCfg.RootCoordDml, Params.CommonCfg.RootCoordDelta) + if err != nil { + log.Warn("failed to generate delta channel name", + zap.String("dmlChannelName", chanNames[i]), + zap.Error(err)) + return fmt.Errorf("failed to generate delta channel name from %s, %w", chanNames[i], err) } } @@ -241,9 +233,6 @@ func (t *CreateCollectionReqTask) Execute(ctx context.Context) error { // add dml channel before send dd msg t.core.chanTimeTick.addDmlChannels(chanNames...) - // also add delta channels - t.core.chanTimeTick.addDeltaChannels(deltaChanNames...) - ids, err := t.core.SendDdCreateCollectionReq(ctx, &ddCollReq, chanNames) if err != nil { return fmt.Errorf("send dd create collection req failed, error = %w", err) @@ -258,7 +247,6 @@ func (t *CreateCollectionReqTask) Execute(ctx context.Context) error { // update meta table after send dd operation if err = t.core.MetaTable.AddCollection(&collInfo, ts, idxInfo, ddOpStr); err != nil { t.core.chanTimeTick.removeDmlChannels(chanNames...) - t.core.chanTimeTick.removeDeltaChannels(deltaChanNames...) // it's ok just to leave create collection message sent, datanode and querynode does't process CreateCollection logic return fmt.Errorf("meta table add collection failed,error = %w", err) } @@ -385,14 +373,6 @@ func (t *DropCollectionReqTask) Execute(ctx context.Context) error { // remove dml channel after send dd msg t.core.chanTimeTick.removeDmlChannels(collMeta.PhysicalChannelNames...) - // remove delta channels - deltaChanNames := make([]string, len(collMeta.PhysicalChannelNames)) - for i, chanName := range collMeta.PhysicalChannelNames { - if deltaChanNames[i], err = funcutil.ConvertChannelName(chanName, Params.CommonCfg.RootCoordDml, Params.CommonCfg.RootCoordDelta); err != nil { - return err - } - } - t.core.chanTimeTick.removeDeltaChannels(deltaChanNames...) 
return nil } diff --git a/internal/rootcoord/timeticksync.go b/internal/rootcoord/timeticksync.go index f3e438afb8..9fd206095b 100644 --- a/internal/rootcoord/timeticksync.go +++ b/internal/rootcoord/timeticksync.go @@ -28,7 +28,6 @@ import ( "github.com/milvus-io/milvus/internal/mq/msgstream" "github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" - "github.com/milvus-io/milvus/internal/util/funcutil" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/timerecord" "github.com/milvus-io/milvus/internal/util/tsoutil" @@ -48,8 +47,7 @@ type timetickSync struct { ctx context.Context sourceID typeutil.UniqueID - dmlChannels *dmlChannels // used for insert - deltaChannels *dmlChannels // used for delete + dmlChannels *dmlChannels // used for insert lock sync.Mutex sess2ChanTsMap map[typeutil.UniqueID]*chanTsMsg @@ -89,33 +87,18 @@ func (c *chanTsMsg) getTimetick(channelName string) typeutil.Timestamp { func newTimeTickSync(ctx context.Context, sourceID int64, factory msgstream.Factory, chanMap map[typeutil.UniqueID][]string) *timetickSync { // initialize dml channels used for insert dmlChannels := newDmlChannels(ctx, factory, Params.CommonCfg.RootCoordDml, Params.RootCoordCfg.DmlChannelNum) - // initialize delta channels used for delete, share Params.DmlChannelNum with dmlChannels - deltaChannels := newDmlChannels(ctx, factory, Params.CommonCfg.RootCoordDelta, Params.RootCoordCfg.DmlChannelNum) // recover physical channels for all collections for collID, chanNames := range chanMap { dmlChannels.addChannels(chanNames...) 
- log.Debug("recover physical channels", zap.Int64("collID", collID), zap.Any("physical channels", chanNames)) - - var err error - deltaChanNames := make([]string, len(chanNames)) - for i, chanName := range chanNames { - deltaChanNames[i], err = funcutil.ConvertChannelName(chanName, Params.CommonCfg.RootCoordDml, Params.CommonCfg.RootCoordDelta) - if err != nil { - log.Error("failed to convert dml channel name to delta channel name", zap.String("chanName", chanName)) - panic("invalid dml channel name " + chanName) - } - } - deltaChannels.addChannels(deltaChanNames...) - log.Debug("recover delta channels", zap.Int64("collID", collID), zap.Any("delta channels", deltaChanNames)) + log.Debug("recover physical channels", zap.Int64("collID", collID), zap.Strings("physical channels", chanNames)) } return &timetickSync{ ctx: ctx, sourceID: sourceID, - dmlChannels: dmlChannels, - deltaChannels: deltaChannels, + dmlChannels: dmlChannels, lock: sync.Mutex{}, sess2ChanTsMap: make(map[typeutil.UniqueID]*chanTsMsg), @@ -384,9 +367,9 @@ func (t *timetickSync) getSessionNum() int { } /////////////////////////////////////////////////////////////////////////////// -// GetDmlChannelName return a valid dml channel name -func (t *timetickSync) getDmlChannelName() string { - return t.dmlChannels.getChannelName() +// getDmlChannelNames returns list of channel names. 
+func (t *timetickSync) getDmlChannelNames(count int) []string { + return t.dmlChannels.getChannelNames(count) } // GetDmlChannelNum return the num of dml channels @@ -419,22 +402,6 @@ func (t *timetickSync) broadcastMarkDmlChannels(chanNames []string, pack *msgstr return t.dmlChannels.broadcastMark(chanNames, pack) } -/////////////////////////////////////////////////////////////////////////////// -// GetDeltaChannelName return a valid delta channel name -func (t *timetickSync) getDeltaChannelName() string { - return t.deltaChannels.getChannelName() -} - -// AddDeltaChannels add delta channels -func (t *timetickSync) addDeltaChannels(names ...string) { - t.deltaChannels.addChannels(names...) -} - -// RemoveDeltaChannels remove delta channels -func (t *timetickSync) removeDeltaChannels(names ...string) { - t.deltaChannels.removeChannels(names...) -} - func minTimeTick(tt ...typeutil.Timestamp) typeutil.Timestamp { var ret typeutil.Timestamp for _, t := range tt {