diff --git a/pkg/util/cache/cache.go b/pkg/util/cache/cache.go index 213a421a55..dbeaeba60a 100644 --- a/pkg/util/cache/cache.go +++ b/pkg/util/cache/cache.go @@ -1,3 +1,19 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package cache import ( @@ -73,6 +89,16 @@ func (s *LazyScavenger[K]) Throw(key K) { s.size -= s.weight(key) } +type Stats struct { + HitCount atomic.Uint64 + MissCount atomic.Uint64 + LoadSuccessCount atomic.Uint64 + LoadFailCount atomic.Uint64 + TotalLoadTimeMs atomic.Uint64 + TotalFinalizeTimeMs atomic.Uint64 + EvictionCount atomic.Uint64 +} + type Cache[K comparable, V any] interface { // Do the operation `doer` on the given key `key`. The key is kept in the cache until the operation // completes. @@ -84,6 +110,9 @@ type Cache[K comparable, V any] interface { // Throws `ErrNoSuchItem` if the key is not found or not able to be loaded from given loader. // Throws `ErrTimeOut` if timed out. DoWait(key K, timeout time.Duration, doer func(V) error) (missing bool, err error) + + // Get stats + Stats() *Stats } type Waiter[K comparable] struct { @@ -105,8 +134,8 @@ type lruCache[K comparable, V any] struct { items map[K]*list.Element accessList *list.List loaderSingleFlight singleflight.Group - - waitQueue *list.List + stats *Stats + waitQueue *list.List loader Loader[K, V] finalizer Finalizer[K, V] @@ -171,6 +200,7 @@ func newLRUCache[K comparable, V any]( accessList: list.New(), waitQueue: list.New(), loaderSingleFlight: singleflight.Group{}, + stats: new(Stats), loader: loader, finalizer: finalizer, scavenger: scavenger, @@ -234,6 +264,10 @@ func (c *lruCache[K, V]) DoWait(key K, timeout time.Duration, doer func(V) error } } +func (c *lruCache[K, V]) Stats() *Stats { + return c.stats +} + func (c *lruCache[K, V]) Unpin(key K) { c.rwlock.Lock() defer c.rwlock.Unlock() @@ -243,7 +277,9 @@ func (c *lruCache[K, V]) Unpin(key K) { } item := e.Value.(*cacheItem[K, V]) item.pinCount.Dec() - c.notifyWaiters() + if item.pinCount.Load() == 0 { + c.notifyWaiters() + } } func (c *lruCache[K, V]) notifyWaiters() { @@ -271,9 +307,11 @@ func (c *lruCache[K, V]) peekAndPin(key K) *cacheItem[K, V] { // GetAndPin gets and pins the given key if it exists func (c *lruCache[K, V]) getAndPin(key K) (*cacheItem[K, V], bool, error) { if item := c.peekAndPin(key); item != nil { + c.stats.HitCount.Inc() return item, false, nil } + c.stats.MissCount.Inc() if c.loader != nil { // Try scavenge if there is room. If not, fail fast. // Note that the test is not accurate since we are not locking `loader` here. @@ -287,11 +325,16 @@ func (c *lruCache[K, V]) getAndPin(key K) (*cacheItem[K, V], bool, error) { return item, nil } + timer := time.Now() value, ok := c.loader(key) + c.stats.TotalLoadTimeMs.Add(uint64(time.Since(timer).Milliseconds())) + if !ok { + c.stats.LoadFailCount.Inc() return nil, ErrNoSuchItem } + c.stats.LoadSuccessCount.Inc() item, err := c.setAndPin(key, value) if err != nil { return nil, err @@ -360,10 +403,13 @@ func (c *lruCache[K, V]) setAndPin(key K, value V) (*cacheItem[K, V], error) { delete(c.items, ek) c.accessList.Remove(e) c.scavenger.Throw(ek) + c.stats.EvictionCount.Inc() if c.finalizer != nil { item := e.Value.(*cacheItem[K, V]) + timer := time.Now() c.finalizer(ek, item.value) + c.stats.TotalFinalizeTimeMs.Add(uint64(time.Since(timer).Milliseconds())) } } diff --git a/pkg/util/cache/cache_test.go b/pkg/util/cache/cache_test.go index c5a5b8612b..ca3b95271c 100644 --- a/pkg/util/cache/cache_test.go +++ b/pkg/util/cache/cache_test.go @@ -137,6 +137,68 @@ func TestLRUCache(t *testing.T) { }) } +func TestStats(t *testing.T) { + cacheBuilder := NewCacheBuilder[int, int]().WithLoader(func(key int) (int, bool) { + return key, true + }) + + t.Run("test loader", func(t *testing.T) { + size := 10 + cache := cacheBuilder.WithCapacity(int64(size)).Build() + stats := cache.Stats() + assert.Equal(t, uint64(0), stats.HitCount.Load()) + assert.Equal(t, uint64(0), stats.MissCount.Load()) + assert.Equal(t, uint64(0), stats.EvictionCount.Load()) + assert.Equal(t, uint64(0), stats.TotalLoadTimeMs.Load()) + assert.Equal(t, uint64(0), stats.TotalFinalizeTimeMs.Load()) + assert.Equal(t, uint64(0), stats.LoadSuccessCount.Load()) + assert.Equal(t, uint64(0), stats.LoadFailCount.Load()) + + for i := 0; i < size; i++ { + _, err := cache.Do(i, func(v int) error { + assert.Equal(t, i, v) + return nil + }) + assert.NoError(t, err) + } + assert.Equal(t, uint64(0), stats.HitCount.Load()) + assert.Equal(t, uint64(size), stats.MissCount.Load()) + assert.Equal(t, uint64(0), stats.EvictionCount.Load()) + // assert.True(t, stats.TotalLoadTimeMs.Load() > 0) + assert.Equal(t, uint64(0), stats.TotalFinalizeTimeMs.Load()) + assert.Equal(t, uint64(size), stats.LoadSuccessCount.Load()) + assert.Equal(t, uint64(0), stats.LoadFailCount.Load()) + + for i := 0; i < size; i++ { + _, err := cache.Do(i, func(v int) error { + assert.Equal(t, i, v) + return nil + }) + assert.NoError(t, err) + } + assert.Equal(t, uint64(size), stats.HitCount.Load()) + assert.Equal(t, uint64(size), stats.MissCount.Load()) + assert.Equal(t, uint64(0), stats.EvictionCount.Load()) + assert.Equal(t, uint64(0), stats.TotalFinalizeTimeMs.Load()) + assert.Equal(t, uint64(size), stats.LoadSuccessCount.Load()) + assert.Equal(t, uint64(0), stats.LoadFailCount.Load()) + + for i := size; i < size*2; i++ { + _, err := cache.Do(i, func(v int) error { + assert.Equal(t, i, v) + return nil + }) + assert.NoError(t, err) + } + assert.Equal(t, uint64(size), stats.HitCount.Load()) + assert.Equal(t, uint64(size*2), stats.MissCount.Load()) + assert.Equal(t, uint64(size), stats.EvictionCount.Load()) + // assert.True(t, stats.TotalFinalizeTimeMs.Load() > 0) + assert.Equal(t, uint64(size*2), stats.LoadSuccessCount.Load()) + assert.Equal(t, uint64(0), stats.LoadFailCount.Load()) + }) +} + func TestLRUCacheConcurrency(t *testing.T) { t.Run("test race condition", func(t *testing.T) { numEvict := new(atomic.Int32) diff --git a/pkg/util/cache/monitor.go b/pkg/util/cache/monitor.go new file mode 100644 index 0000000000..7edeabf156 --- /dev/null +++ b/pkg/util/cache/monitor.go @@ -0,0 +1,39 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cache + +import ( + "github.com/prometheus/client_golang/prometheus" +) + +// WIP: this function is a showcase of how to use prometheus, do not use it in production. +func PrometheusCacheMonitor[K comparable, V any](c Cache[K, V], namespace, subsystem string) { + hitRate := prometheus.NewGaugeFunc( + prometheus.GaugeOpts{ + Namespace: namespace, + Subsystem: subsystem, + Name: "cache_hitrate", + Help: "hit rate equals hitcount / (hitcount + misscount)", + }, + func() float64 { + hit := float64(c.Stats().HitCount.Load()) + miss := float64(c.Stats().MissCount.Load()) + return hit / (hit + miss) + }) + // TODO: adding more metrics. + prometheus.MustRegister(hitRate) +}