From 73b6c132454828d40483066a3fbd4aed3e735cb3 Mon Sep 17 00:00:00 2001 From: congqixia Date: Mon, 6 Feb 2023 15:27:54 +0800 Subject: [PATCH] Expand singleflight to Get and make writeEvent sync (#21993) Signed-off-by: Congqi Xia --- internal/util/cache/local_cache.go | 96 +++++++++++++------------ internal/util/cache/local_cache_test.go | 2 +- internal/util/cache/policy.go | 8 +++ 3 files changed, 61 insertions(+), 45 deletions(-) diff --git a/internal/util/cache/local_cache.go b/internal/util/cache/local_cache.go index 9cbdce1957..a2a7705c9c 100644 --- a/internal/util/cache/local_cache.go +++ b/internal/util/cache/local_cache.go @@ -117,7 +117,7 @@ func (c *localCache[K, V]) init() { func (c *localCache[K, V]) Close() error { if atomic.CompareAndSwapInt32(&c.closing, 0, 1) { // Do not close events channel to avoid panic when cache is still being used. - c.events <- entryEvent{nil, eventClose} + c.events <- entryEvent{nil, eventClose, make(chan struct{})} // Wait for the goroutine to close this channel c.closeWG.Wait() } @@ -167,7 +167,7 @@ func (c *localCache[K, V]) Put(k K, v V) { en.setValue(v) en.setWriteTime(now.UnixNano()) } - c.sendEvent(eventWrite, en) + <-c.sendEvent(eventWrite, en) } // Invalidate removes the entry associated with key k. @@ -204,28 +204,36 @@ func (c *localCache[K, V]) Scan(filter func(K, V) bool) map[K]V { // if it is not in the cache. The returned value is only cached when loader returns // nil error. func (c *localCache[K, V]) Get(k K) (V, error) { - en := c.cache.get(k, sum(k)) - if en == nil { - c.stats.RecordMisses(1) - return c.load(k) - } - // Check if this entry needs to be refreshed - now := currentTime() - if c.isExpired(en, now) { - c.stats.RecordMisses(1) - if c.loader == nil { - c.sendEvent(eventDelete, en) - } else { - // Update value if expired - c.setEntryAccessTime(en, now) - c.refresh(en) + val, err, _ := c.singleflight.Do(fmt.Sprintf("%v", k), func() (any, error) { + en := c.cache.get(k, sum(k)) + if en == nil { + c.stats.RecordMisses(1) + return c.load(k) } - } else { - c.stats.RecordHits(1) - c.setEntryAccessTime(en, now) - c.sendEvent(eventAccess, en) + // Check if this entry needs to be refreshed + now := currentTime() + if c.isExpired(en, now) { + c.stats.RecordMisses(1) + if c.loader == nil { + c.sendEvent(eventDelete, en) + } else { + // Update value if expired + c.setEntryAccessTime(en, now) + c.refresh(en) + } + } else { + c.stats.RecordHits(1) + c.setEntryAccessTime(en, now) + c.sendEvent(eventAccess, en) + } + return en.getValue(), nil + }) + var v V + if err != nil { + return v, err } - return en.getValue().(V), nil + v = val.(V) + return v, nil } // Refresh synchronously load and block until it value is loaded. @@ -274,9 +282,11 @@ func (c *localCache[K, V]) processEntries() { switch e.event { case eventWrite: c.write(e.entry) + e.Done() c.postWriteCleanup() case eventAccess: c.access(e.entry) + e.Done() c.postReadCleanup() case eventDelete: if e.entry == nil { @@ -284,6 +294,7 @@ func (c *localCache[K, V]) processEntries() { } else { c.remove(e.entry) } + e.Done() c.postReadCleanup() case eventClose: c.removeAll() @@ -293,10 +304,14 @@ func (c *localCache[K, V]) processEntries() { } // sendEvent sends event only when the cache is not closing/closed. -func (c *localCache[K, V]) sendEvent(typ event, en *entry) { +func (c *localCache[K, V]) sendEvent(typ event, en *entry) chan struct{} { + ch := make(chan struct{}) if atomic.LoadInt32(&c.closing) == 0 { - c.events <- entryEvent{en, typ} + c.events <- entryEvent{en, typ, ch} + return ch } + close(ch) + return ch } // This function must only be called from processEntries goroutine. @@ -349,29 +364,22 @@ func (c *localCache[K, V]) load(k K) (v V, err error) { return ret, errors.New("cache loader function must be set") } - // use singleflight here - val, err, _ := c.singleflight.Do(fmt.Sprintf("%v", k), func() (any, error) { - start := currentTime() - v, err := c.loader(k) - now := currentTime() - loadTime := now.Sub(start) - if err != nil { - c.stats.RecordLoadError(loadTime) - return v, err - } - c.stats.RecordLoadSuccess(loadTime) - en := newEntry(k, v, sum(k)) - c.setEntryWriteTime(en, now) - c.setEntryAccessTime(en, now) - c.sendEvent(eventWrite, en) - - return v, err - }) + start := currentTime() + v, err = c.loader(k) + now := currentTime() + loadTime := now.Sub(start) if err != nil { + c.stats.RecordLoadError(loadTime) return v, err } - v = val.(V) - return v, nil + c.stats.RecordLoadSuccess(loadTime) + en := newEntry(k, v, sum(k)) + c.setEntryWriteTime(en, now) + c.setEntryAccessTime(en, now) + // wait event processed + <-c.sendEvent(eventWrite, en) + + return v, err } // refresh reloads value for the given key. If loader returns an error, diff --git a/internal/util/cache/local_cache_test.go b/internal/util/cache/local_cache_test.go index 3e5b04d6bf..9199dfef2c 100644 --- a/internal/util/cache/local_cache_test.go +++ b/internal/util/cache/local_cache_test.go @@ -291,7 +291,7 @@ func TestExpireAfterWrite(t *testing.T) { assert.Equal(t, 2, loadCount) } -func TestRefreshAterWrite(t *testing.T) { +func TestRefreshAfterWrite(t *testing.T) { var mutex sync.Mutex loaded := make(map[int]int) loader := func(k int) (int, error) { diff --git a/internal/util/cache/policy.go b/internal/util/cache/policy.go index 059271a639..a3dca68d27 100644 --- a/internal/util/cache/policy.go +++ b/internal/util/cache/policy.go @@ -132,6 +132,14 @@ const ( type entryEvent struct { entry *entry event event + done chan struct{} +} + +// Done closes event signal channel. +func (e *entryEvent) Done() { + if e.done != nil { + close(e.done) + } } // cache is a data structure for cache entries.