diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index 7a4aefdc25..a8ea81bed8 100644 --- a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -54,6 +54,7 @@ type Cache interface { // GetCollectionSchema get collection's schema. GetCollectionSchema(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error) GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) ([]*querypb.ShardLeadersList, error) + ClearShards(collectionName string) RemoveCollection(ctx context.Context, collectionName string) RemovePartition(ctx context.Context, collectionName string, partitionName string) @@ -541,8 +542,6 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam zap.String("collectionName", collectionName)) } - m.mu.Lock() - defer m.mu.Unlock() req := &querypb.GetShardLeadersRequest{ Base: &commonpb.MsgBase{ MsgType: commonpb.MsgType_GetShardLeaders, @@ -560,6 +559,23 @@ func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionNam shards := resp.GetShards() + m.mu.Lock() m.collInfo[collectionName].shardLeaders = shards + m.mu.Unlock() + return shards, nil } + +// ClearShards clear the shard leader cache of a collection +func (m *MetaCache) ClearShards(collectionName string) { + log.Info("clearing shard cache for collection", zap.String("collectionName", collectionName)) + m.mu.Lock() + defer m.mu.Unlock() + _, ok := m.collInfo[collectionName] + + if !ok { + return + } + + m.collInfo[collectionName].shardLeaders = nil +} diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index 718aad0c68..19f8edb165 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -356,5 +356,46 @@ func TestMetaCache_GetShards(t *testing.T) { assert.Equal(t, 3, len(shards[0].GetNodeAddrs())) assert.Equal(t, 3, len(shards[0].GetNodeIds())) }) +} + +func TestMetaCache_ClearShards(t *testing.T) { + client := &MockRootCoordClientInterface{} + err := InitMetaCache(client) + require.Nil(t, err) + + var ( + ctx = context.TODO() + collectionName = "collection1" + qc = NewQueryCoordMock() + ) + qc.Init() + qc.Start() + defer qc.Stop() + + t.Run("Clear with no collection info", func(t *testing.T) { + globalMetaCache.ClearShards("collection_not_exist") + }) + + t.Run("Clear valid collection empty cache", func(t *testing.T) { + globalMetaCache.ClearShards(collectionName) + }) + + t.Run("Clear valid collection valid cache", func(t *testing.T) { + + qc.validShardLeaders = true + shards, err := globalMetaCache.GetShards(ctx, true, collectionName, qc) + require.NoError(t, err) + require.NotEmpty(t, shards) + require.Equal(t, 1, len(shards)) + require.Equal(t, 3, len(shards[0].GetNodeAddrs())) + require.Equal(t, 3, len(shards[0].GetNodeIds())) + + globalMetaCache.ClearShards(collectionName) + + qc.validShardLeaders = false + shards, err = globalMetaCache.GetShards(ctx, true, collectionName, qc) + assert.Error(t, err) + assert.Empty(t, shards) + }) } diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 45c399c6c4..9fa4390a8d 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -2864,6 +2864,7 @@ func (rct *releaseCollectionTask) Execute(ctx context.Context) (err error) { } func (rct *releaseCollectionTask) PostExecute(ctx context.Context) error { + globalMetaCache.ClearShards(rct.CollectionName) return nil } @@ -3056,6 +3057,7 @@ func (rpt *releasePartitionsTask) Execute(ctx context.Context) (err error) { } func (rpt *releasePartitionsTask) PostExecute(ctx context.Context) error { + globalMetaCache.ClearShards(rpt.CollectionName) return nil }