diff --git a/internal/util/cache/lru_cache.go b/internal/util/cache/lru_cache.go index 5694f3e5c8..263bab3eca 100644 --- a/internal/util/cache/lru_cache.go +++ b/internal/util/cache/lru_cache.go @@ -18,24 +18,25 @@ package cache import ( "container/list" - "context" "errors" "fmt" "sync" ) +// LRU generic utility for lru cache. type LRU[K comparable, V any] struct { - ctx context.Context - cancel context.CancelFunc evictList *list.List items map[K]*list.Element capacity int onEvicted func(k K, v V) m sync.RWMutex evictedCh chan *entry[K, V] + closeCh chan struct{} + closeOnce sync.Once stats *Stats } +// Stats is the model for cache statistics. type Stats struct { hitCount float32 evictedCount float32 @@ -43,6 +44,7 @@ type Stats struct { writeCount float32 } +// String implement stringer for printing. func (s *Stats) String() string { var hitRatio float32 var evictedRatio float32 @@ -59,29 +61,30 @@ type entry[K comparable, V any] struct { value V } +// NewLRU creates a LRU cache with provided capacity and `onEvicted` function. +// `onEvicted` will be executed when an item is chosed to be evicted. func NewLRU[K comparable, V any](capacity int, onEvicted func(k K, v V)) (*LRU[K, V], error) { if capacity <= 0 { return nil, errors.New("cache size must be positive") } - ctx, cancel := context.WithCancel(context.Background()) c := &LRU[K, V]{ - ctx: ctx, - cancel: cancel, capacity: capacity, evictList: list.New(), items: make(map[K]*list.Element), onEvicted: onEvicted, evictedCh: make(chan *entry[K, V], 16), + closeCh: make(chan struct{}), stats: &Stats{}, } go c.evictedWorker() return c, nil } +// evictedWorker executes onEvicted function for each evicted items. func (c *LRU[K, V]) evictedWorker() { for { select { - case <-c.ctx.Done(): + case <-c.closeCh: return case e, ok := <-c.evictedCh: if ok { @@ -93,9 +96,27 @@ func (c *LRU[K, V]) evictedWorker() { } } +// closed returns whether cache is closed. +func (c *LRU[K, V]) closed() bool { + select { + case <-c.closeCh: + return true + default: + return false + } +} + +// Add puts an item into cache. func (c *LRU[K, V]) Add(key K, value V) { c.m.Lock() defer c.m.Unlock() + + if c.closed() { + // evict since cache closed + c.onEvicted(key, value) + return + } + c.stats.writeCount++ if e, ok := c.items[key]; ok { c.evictList.MoveToFront(e) @@ -120,9 +141,17 @@ func (c *LRU[K, V]) Add(key K, value V) { } } +// Get returns value for provided key. func (c *LRU[K, V]) Get(key K) (value V, ok bool) { c.m.RLock() defer c.m.RUnlock() + + var zeroV V + if c.closed() { + // cache closed, returns nothing + return zeroV, false + } + c.stats.readCount++ if e, ok := c.items[key]; ok { c.stats.hitCount++ @@ -131,13 +160,18 @@ func (c *LRU[K, V]) Get(key K) (value V, ok bool) { return kv.value, true } - var zeroV V return zeroV, false } +// Remove removes item associated with provided key. func (c *LRU[K, V]) Remove(key K) { c.m.Lock() defer c.m.Unlock() + + if c.closed() { + return + } + if e, ok := c.items[key]; ok { c.evictList.Remove(e) kv := e.Value.(*entry[K, V]) @@ -148,16 +182,24 @@ func (c *LRU[K, V]) Remove(key K) { } } +// Contains returns whether items with provided key exists in cache. func (c *LRU[K, V]) Contains(key K) bool { c.m.RLock() defer c.m.RUnlock() + if c.closed() { + return false + } _, ok := c.items[key] return ok } +// Keys returns all the keys exist in cache. func (c *LRU[K, V]) Keys() []K { c.m.RLock() defer c.m.RUnlock() + if c.closed() { + return nil + } keys := make([]K, len(c.items)) i := 0 for ent := c.evictList.Back(); ent != nil; ent = ent.Prev() { @@ -167,16 +209,22 @@ func (c *LRU[K, V]) Keys() []K { return keys } +// Len returns items count in cache. func (c *LRU[K, V]) Len() int { c.m.RLock() defer c.m.RUnlock() + if c.closed() { + return 0 + } return c.evictList.Len() } +// Capacity returns cache capacity. func (c *LRU[K, V]) Capacity() int { return c.capacity } +// Purge removes all items and put them into evictedCh. func (c *LRU[K, V]) Purge() { c.m.Lock() defer c.m.Unlock() @@ -189,9 +237,14 @@ func (c *LRU[K, V]) Purge() { c.evictList.Init() } +// Resize changes the capacity of cache. func (c *LRU[K, V]) Resize(capacity int) int { c.m.Lock() defer c.m.Unlock() + if c.closed() { + return 0 + } + c.capacity = capacity if capacity >= c.evictList.Len() { return 0 @@ -211,35 +264,48 @@ func (c *LRU[K, V]) Resize(capacity int) int { return diff } +// GetOldest returns the oldest item in cache. func (c *LRU[K, V]) GetOldest() (K, V, bool) { c.m.RLock() defer c.m.RUnlock() + var ( + zeroK K + zeroV V + ) + if c.closed() { + return zeroK, zeroV, false + } ent := c.evictList.Back() if ent != nil { kv := ent.Value.(*entry[K, V]) return kv.key, kv.value, true } - var ( - zeroK K - zeroV V - ) return zeroK, zeroV, false } +// Close cleans up the cache resources. func (c *LRU[K, V]) Close() { - c.Purge() - c.cancel() - remain := len(c.evictedCh) - for i := 0; i < remain; i++ { - e, ok := <-c.evictedCh - if ok { + c.closeOnce.Do(func() { + // fetch lock to + // - wait on-going operations done + // - block incoming operations + c.m.Lock() + close(c.closeCh) + c.m.Unlock() + + // execute purge in a goroutine, otherwise Purge may block forever putting evictedCh + go func() { + c.Purge() + close(c.evictedCh) + }() + for e := range c.evictedCh { c.onEvicted(e.key, e.value) } - } - close(c.evictedCh) + }) } +// Stats returns cache statistics. func (c *LRU[K, V]) Stats() *Stats { return c.stats } diff --git a/internal/util/cache/lru_cache_test.go b/internal/util/cache/lru_cache_test.go index 700bf74df0..b66a49480a 100644 --- a/internal/util/cache/lru_cache_test.go +++ b/internal/util/cache/lru_cache_test.go @@ -22,6 +22,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewLRU(t *testing.T) { @@ -399,3 +400,41 @@ func TestLRU_Resize(t *testing.T) { return atomic.LoadInt32(&evicted) == 1 }, 1*time.Second, 100*time.Millisecond) } + +func TestLRU_closed(t *testing.T) { + evicted := int32(0) + c, err := NewLRU(2, func(string, string) { atomic.AddInt32(&evicted, 1) }) + require.NoError(t, err) + + c.Close() + + c.Add("testKey", "testValue") + assert.Equal(t, int32(1), evicted) + + _, ok := c.Get("testKey") + assert.False(t, ok) + + assert.NotPanics(t, func() { + c.Remove("testKey") + }) + + contains := c.Contains("testKey") + assert.False(t, contains) + + keys := c.Keys() + assert.Nil(t, keys) + + l := c.Len() + assert.Equal(t, 0, l) + + diff := c.Resize(1) + assert.Equal(t, 0, diff) + assert.Equal(t, 2, c.Capacity()) + + _, _, ok = c.GetOldest() + assert.False(t, ok) + + assert.NotPanics(t, func() { + c.Close() + }, "double close") +}