diff --git a/internal/proto/query_coord.proto b/internal/proto/query_coord.proto index ed0e152bbd..75cfe78599 100644 --- a/internal/proto/query_coord.proto +++ b/internal/proto/query_coord.proto @@ -187,6 +187,7 @@ message ShowCollectionsResponse { repeated int64 inMemory_percentages = 3; repeated bool query_service_available = 4; repeated int64 refresh_progress = 5; + repeated schema.LongArray load_fields = 6; } message ShowPartitionsRequest { @@ -214,6 +215,7 @@ message LoadCollectionRequest { bool refresh = 7; // resource group names repeated string resource_groups = 8; + repeated int64 load_fields = 9; } message ReleaseCollectionRequest { @@ -244,6 +246,7 @@ message LoadPartitionsRequest { // resource group names repeated string resource_groups = 9; repeated index.IndexInfo index_info_list = 10; + repeated int64 load_fields = 11; } message ReleasePartitionsRequest { @@ -313,6 +316,7 @@ message LoadMetaInfo { string metric_type = 4 [deprecated = true]; string db_name = 5; // Only used for metrics label. string resource_group = 6; // Only used for metrics label. + repeated int64 load_fields = 7; } message WatchDmChannelsRequest { @@ -650,6 +654,7 @@ message CollectionLoadInfo { map<int64, int64> field_indexID = 5; LoadType load_type = 6; int32 recover_times = 7; + repeated int64 load_fields = 8; } message PartitionLoadInfo { diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index f55f5fe838..886049dee3 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -123,7 +123,7 @@ func (node *Proxy) InvalidateCollectionMetaCache(ctx context.Context, request *p if globalMetaCache != nil { switch msgType { - case commonpb.MsgType_DropCollection, commonpb.MsgType_RenameCollection, commonpb.MsgType_DropAlias, commonpb.MsgType_AlterAlias: + case commonpb.MsgType_DropCollection, commonpb.MsgType_RenameCollection, commonpb.MsgType_DropAlias, commonpb.MsgType_AlterAlias, commonpb.MsgType_LoadCollection: if collectionName != "" { globalMetaCache.RemoveCollection(ctx, request.GetDbName(), collectionName) // no need to return error, though the collection may not be cached globalMetaCache.DeprecateShardCache(request.GetDbName(), collectionName) diff --git a/internal/proxy/impl_test.go b/internal/proxy/impl_test.go index c9f7f2439e..6104b99be5 100644 --- a/internal/proxy/impl_test.go +++ b/internal/proxy/impl_test.go @@ -249,6 +249,8 @@ func TestProxy_ResourceGroup(t *testing.T) { qc := mocks.NewMockQueryCoordClient(t) node.SetQueryCoordClient(qc) + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() + tsoAllocatorIns := newMockTsoAllocator() node.sched, err = newTaskScheduler(node.ctx, tsoAllocatorIns, node.factory) assert.NoError(t, err) diff --git a/internal/proxy/lb_policy_test.go b/internal/proxy/lb_policy_test.go index 86551df24a..a83e1a5c06 100644 --- a/internal/proxy/lb_policy_test.go +++ b/internal/proxy/lb_policy_test.go @@ -67,6 +67,7 @@ func (s *LBPolicySuite) SetupTest() { successStatus := commonpb.Status{ErrorCode: commonpb.ErrorCode_Success} qc := mocks.NewMockQueryCoordClient(s.T()) qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(&successStatus, nil) + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ Status: &successStatus, diff --git a/internal/proxy/meta_cache.go b/internal/proxy/meta_cache.go index e74940c4d7..49bc8a4920 100644 ---
a/internal/proxy/meta_cache.go +++ b/internal/proxy/meta_cache.go @@ -128,7 +128,7 @@ type schemaInfo struct { schemaHelper *typeutil.SchemaHelper } -func newSchemaInfo(schema *schemapb.CollectionSchema) *schemaInfo { +func newSchemaInfoWithLoadFields(schema *schemapb.CollectionSchema, loadFields []int64) *schemaInfo { fieldMap := typeutil.NewConcurrentMap[string, int64]() hasPartitionkey := false var pkField *schemapb.FieldSchema @@ -142,7 +142,7 @@ func newSchemaInfo(schema *schemapb.CollectionSchema) *schemaInfo { } } // schema shall be verified before - schemaHelper, _ := typeutil.CreateSchemaHelper(schema) + schemaHelper, _ := typeutil.CreateSchemaHelperWithLoadFields(schema, loadFields) return &schemaInfo{ CollectionSchema: schema, fieldMap: fieldMap, @@ -152,6 +152,10 @@ } } +func newSchemaInfo(schema *schemapb.CollectionSchema) *schemaInfo { + return newSchemaInfoWithLoadFields(schema, nil) +} + func (s *schemaInfo) MapFieldID(name string) (int64, bool) { return s.fieldMap.Get(name) } @@ -167,6 +171,83 @@ func (s *schemaInfo) GetPkField() (*schemapb.FieldSchema, error) { return s.pkField, nil } +// GetLoadFieldIDs returns the field IDs for the load field list. +// If input `loadFields` is empty, use collection schema definition. +// Otherwise, perform the load field list constraint check, then return the field IDs. +func (s *schemaInfo) GetLoadFieldIDs(loadFields []string, skipDynamicField bool) ([]int64, error) { + if len(loadFields) == 0 { + // skip the check logic since create collection already performed the rule check + return common.GetCollectionLoadFields(s.CollectionSchema, skipDynamicField), nil + } + + fieldIDs := typeutil.NewSet[int64]() + fields := make([]*schemapb.FieldSchema, 0, len(loadFields)) + for _, name := range loadFields { + fieldSchema, err := s.schemaHelper.GetFieldFromName(name) + if err != nil { + return nil, err + } + + fields = append(fields, fieldSchema) + fieldIDs.Insert(fieldSchema.GetFieldID()) + } + + // only append dynamic field when skipFlag == false + if !skipDynamicField { + // find dynamic field + dynamicField := lo.FindOrElse(s.Fields, nil, func(field *schemapb.FieldSchema) bool { + return field.IsDynamic + }) + + // if the dynamic field exists + if dynamicField != nil { + fieldIDs.Insert(dynamicField.GetFieldID()) + fields = append(fields, dynamicField) + } + } + + // validate the load field list + if err := s.validateLoadFields(loadFields, fields); err != nil { + return nil, err + } + + return fieldIDs.Collect(), nil +} + +func (s *schemaInfo) validateLoadFields(names []string, fields []*schemapb.FieldSchema) error { + // ignore error if not found + partitionKeyField, _ := s.schemaHelper.GetPartitionKeyField() + + var hasPrimaryKey, hasPartitionKey, hasVector bool + for _, field := range fields { + if field.GetFieldID() == s.pkField.GetFieldID() { + hasPrimaryKey = true + } + if typeutil.IsVectorType(field.GetDataType()) { + hasVector = true + } + if field.IsPartitionKey { + hasPartitionKey = true + } + } + + if !hasPrimaryKey { + return merr.WrapErrParameterInvalidMsg("load field list %v does not contain primary key field %s", names, s.pkField.GetName()) + } + if !hasVector { + return merr.WrapErrParameterInvalidMsg("load field list %v does not contain vector field", names) + } + if partitionKeyField != nil && !hasPartitionKey { + return merr.WrapErrParameterInvalidMsg("load field list %v does not contain partition key field %s", names,
partitionKeyField.GetName()) + } + return nil +} + +func (s *schemaInfo) IsFieldLoaded(fieldID int64) bool { + return s.schemaHelper.IsFieldLoaded(fieldID) +} + // partitionInfos contains the cached collection partition information. type partitionInfos struct { partitionInfos []*partitionInfo @@ -366,6 +447,11 @@ func (m *MetaCache) update(ctx context.Context, database, collectionName string, return nil, err } + loadFields, err := m.getCollectionLoadFields(ctx, collection.CollectionID) + if err != nil { + return nil, err + } + // check that partitionID, createdTimestamp and utcstamp have the same element numbers if len(partitions.PartitionNames) != len(partitions.CreatedTimestamps) || len(partitions.PartitionNames) != len(partitions.CreatedUtcTimestamps) { return nil, merr.WrapErrParameterInvalidMsg("partition names and timestamps number is not aligned, response: %s", partitions.String()) @@ -393,7 +479,7 @@ func (m *MetaCache) update(ctx context.Context, database, collectionName string, return nil, err } - schemaInfo := newSchemaInfo(collection.Schema) + schemaInfo := newSchemaInfoWithLoadFields(collection.Schema, loadFields) m.collInfo[database][collectionName] = &collectionInfo{ collID: collection.CollectionID, schema: schemaInfo, @@ -760,6 +846,28 @@ func (m *MetaCache) showPartitions(ctx context.Context, dbName string, collectio return partitions, nil } +func (m *MetaCache) getCollectionLoadFields(ctx context.Context, collectionID UniqueID) ([]int64, error) { + req := &querypb.ShowCollectionsRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + CollectionIDs: []int64{collectionID}, + } + + resp, err := m.queryCoord.ShowCollections(ctx, req) + if err != nil { + if errors.Is(err, merr.ErrCollectionNotLoaded) { + return []int64{}, nil + } + return nil, err + } + // backward compatibility, ignore HPL logic + if len(resp.GetLoadFields()) < 1 { + return []int64{}, nil + } + return resp.GetLoadFields()[0].GetData(), nil +} + func (m *MetaCache) describeDatabase(ctx context.Context, dbName string) (*rootcoordpb.DescribeDatabaseResponse, error) { req := &rootcoordpb.DescribeDatabaseRequest{ DbName: dbName, diff --git a/internal/proxy/meta_cache_test.go b/internal/proxy/meta_cache_test.go index f2459b674a..ec8fe63c37 100644 --- a/internal/proxy/meta_cache_test.go +++ b/internal/proxy/meta_cache_test.go @@ -39,6 +39,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/proto/rootcoordpb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/util/crypto" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" @@ -195,6 +196,9 @@ func TestMetaCache_GetCollection(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} queryCoord := &mocks.MockQueryCoordClient{} + + queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() + mgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.NoError(t, err) @@ -245,6 +249,8 @@ func TestMetaCache_GetBasicCollectionInfo(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} queryCoord := &mocks.MockQueryCoordClient{} + + queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() mgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord,
mgr) assert.NoError(t, err) @@ -277,6 +283,7 @@ func TestMetaCache_GetCollectionName(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} queryCoord := &mocks.MockQueryCoordClient{} + queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() mgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.NoError(t, err) @@ -327,6 +334,7 @@ func TestMetaCache_GetCollectionFailure(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} queryCoord := &mocks.MockQueryCoordClient{} + queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() mgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.NoError(t, err) @@ -360,6 +368,7 @@ func TestMetaCache_GetNonExistCollection(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} queryCoord := &mocks.MockQueryCoordClient{} + queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() mgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.NoError(t, err) @@ -376,6 +385,7 @@ func TestMetaCache_GetPartitionID(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} queryCoord := &mocks.MockQueryCoordClient{} + queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() mgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.NoError(t, err) @@ -398,6 +408,7 @@ func TestMetaCache_ConcurrentTest1(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} queryCoord := &mocks.MockQueryCoordClient{} + queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() mgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.NoError(t, err) @@ -452,6 +463,7 @@ func TestMetaCache_GetPartitionError(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} queryCoord := &mocks.MockQueryCoordClient{} + queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() mgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord, mgr) assert.NoError(t, err) @@ -805,6 +817,7 @@ func TestMetaCache_Database(t *testing.T) { ctx := context.Background() rootCoord := &MockRootCoordClientInterface{} queryCoord := &mocks.MockQueryCoordClient{} + queryCoord.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() shardMgr := newShardClientMgr() err := InitMetaCache(ctx, rootCoord, queryCoord, shardMgr) assert.NoError(t, err) @@ -1119,3 +1132,212 @@ func TestMetaCache_InvalidateShardLeaderCache(t *testing.T) { assert.Len(t, nodeInfos["channel-1"], 3) assert.Equal(t, called.Load(), int32(2)) } + +func TestSchemaInfo_GetLoadFieldIDs(t *testing.T) { + type testCase struct { + tag string + schema *schemapb.CollectionSchema + loadFields []string + skipDynamicField bool + expectResult []int64 + expectErr bool + } + + rowIDField := &schemapb.FieldSchema{ + FieldID: common.RowIDField, + Name: common.RowIDFieldName, + DataType: schemapb.DataType_Int64, + } + timestampField := &schemapb.FieldSchema{ + FieldID: 
common.TimeStampField, + Name: common.TimeStampFieldName, + DataType: schemapb.DataType_Int64, + } + pkField := &schemapb.FieldSchema{ + FieldID: common.StartOfUserFieldID, + Name: "pk", + DataType: schemapb.DataType_Int64, + IsPrimaryKey: true, + } + scalarField := &schemapb.FieldSchema{ + FieldID: common.StartOfUserFieldID + 1, + Name: "text", + DataType: schemapb.DataType_VarChar, + } + scalarFieldSkipLoad := &schemapb.FieldSchema{ + FieldID: common.StartOfUserFieldID + 1, + Name: "text", + DataType: schemapb.DataType_VarChar, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.FieldSkipLoadKey, Value: "true"}, + }, + } + partitionKeyField := &schemapb.FieldSchema{ + FieldID: common.StartOfUserFieldID + 2, + Name: "part_key", + DataType: schemapb.DataType_Int64, + IsPartitionKey: true, + } + vectorField := &schemapb.FieldSchema{ + FieldID: common.StartOfUserFieldID + 3, + Name: "vector", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "768"}, + }, + } + dynamicField := &schemapb.FieldSchema{ + FieldID: common.StartOfUserFieldID + 4, + Name: common.MetaFieldName, + DataType: schemapb.DataType_JSON, + IsDynamic: true, + } + + testCases := []testCase{ + { + tag: "default", + schema: &schemapb.CollectionSchema{ + EnableDynamicField: true, + Fields: []*schemapb.FieldSchema{ + rowIDField, + timestampField, + pkField, + scalarField, + partitionKeyField, + vectorField, + dynamicField, + }, + }, + loadFields: nil, + skipDynamicField: false, + expectResult: []int64{common.StartOfUserFieldID, common.StartOfUserFieldID + 1, common.StartOfUserFieldID + 2, common.StartOfUserFieldID + 3, common.StartOfUserFieldID + 4}, + expectErr: false, + }, + { + tag: "default_from_schema", + schema: &schemapb.CollectionSchema{ + EnableDynamicField: true, + Fields: []*schemapb.FieldSchema{ + rowIDField, + timestampField, + pkField, + scalarFieldSkipLoad, + partitionKeyField, + vectorField, + dynamicField, + }, + }, + loadFields: nil, + skipDynamicField: false, + expectResult: []int64{common.StartOfUserFieldID, common.StartOfUserFieldID + 2, common.StartOfUserFieldID + 3, common.StartOfUserFieldID + 4}, + expectErr: false, + }, + { + tag: "load_fields", + schema: &schemapb.CollectionSchema{ + EnableDynamicField: true, + Fields: []*schemapb.FieldSchema{ + rowIDField, + timestampField, + pkField, + scalarField, + partitionKeyField, + vectorField, + dynamicField, + }, + }, + loadFields: []string{"pk", "part_key", "vector"}, + skipDynamicField: false, + expectResult: []int64{common.StartOfUserFieldID, common.StartOfUserFieldID + 2, common.StartOfUserFieldID + 3, common.StartOfUserFieldID + 4}, + expectErr: false, + }, + { + tag: "load_fields_skip_dynamic", + schema: &schemapb.CollectionSchema{ + EnableDynamicField: true, + Fields: []*schemapb.FieldSchema{ + rowIDField, + timestampField, + pkField, + scalarField, + partitionKeyField, + vectorField, + dynamicField, + }, + }, + loadFields: []string{"pk", "part_key", "vector"}, + skipDynamicField: true, + expectResult: []int64{common.StartOfUserFieldID, common.StartOfUserFieldID + 2, common.StartOfUserFieldID + 3}, + expectErr: false, + }, + { + tag: "pk_not_loaded", + schema: &schemapb.CollectionSchema{ + EnableDynamicField: true, + Fields: []*schemapb.FieldSchema{ + rowIDField, + timestampField, + pkField, + scalarField, + partitionKeyField, + vectorField, + dynamicField, + }, + }, + loadFields: []string{"part_key", "vector"}, + skipDynamicField: true, + expectErr: true, + }, + { + tag: 
"part_key_not_loaded", + schema: &schemapb.CollectionSchema{ + EnableDynamicField: true, + Fields: []*schemapb.FieldSchema{ + rowIDField, + timestampField, + pkField, + scalarField, + partitionKeyField, + vectorField, + dynamicField, + }, + }, + loadFields: []string{"pk", "vector"}, + skipDynamicField: true, + expectErr: true, + }, + { + tag: "vector_not_loaded", + schema: &schemapb.CollectionSchema{ + EnableDynamicField: true, + Fields: []*schemapb.FieldSchema{ + rowIDField, + timestampField, + pkField, + scalarField, + partitionKeyField, + vectorField, + dynamicField, + }, + }, + loadFields: []string{"pk", "part_key"}, + skipDynamicField: true, + expectErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.tag, func(t *testing.T) { + info := newSchemaInfo(tc.schema) + + result, err := info.GetLoadFieldIDs(tc.loadFields, tc.skipDynamicField) + if tc.expectErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + assert.ElementsMatch(t, tc.expectResult, result) + }) + } +} diff --git a/internal/proxy/msg_pack_test.go b/internal/proxy/msg_pack_test.go index 29873f11b0..9d199dd166 100644 --- a/internal/proxy/msg_pack_test.go +++ b/internal/proxy/msg_pack_test.go @@ -29,6 +29,8 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/paramtable" @@ -153,8 +155,10 @@ func TestRepackInsertDataWithPartitionKey(t *testing.T) { rc := NewRootCoordMock() defer rc.Close() + qc := &mocks.MockQueryCoordClient{} + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() - err := InitMetaCache(ctx, rc, nil, nil) + err := InitMetaCache(ctx, rc, qc, nil) assert.NoError(t, err) idAllocator, err := allocator.NewIDAllocator(ctx, rc, paramtable.GetNodeID()) diff --git a/internal/proxy/task.go b/internal/proxy/task.go index fc01844170..e4b80f54d8 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -1611,6 +1611,13 @@ func (t *loadCollectionTask) Execute(ctx context.Context) (err error) { if err != nil { return err } + // prepare load field list + // TODO use load collection load field list after proto merged + loadFields, err := collSchema.GetLoadFieldIDs(t.GetLoadFields(), t.GetSkipLoadDynamicField()) + if err != nil { + return err + } + // check index indexResponse, err := t.datacoord.DescribeIndex(ctx, &indexpb.DescribeIndexRequest{ CollectionID: collID, @@ -1658,6 +1665,7 @@ func (t *loadCollectionTask) Execute(ctx context.Context) (err error) { FieldIndexID: fieldIndexIDs, Refresh: t.Refresh, ResourceGroups: t.ResourceGroups, + LoadFields: loadFields, } log.Debug("send LoadCollectionRequest to query coordinator", zap.Any("schema", request.Schema)) @@ -1855,6 +1863,11 @@ func (t *loadPartitionsTask) Execute(ctx context.Context) error { if err != nil { return err } + // prepare load field list + loadFields, err := collSchema.GetLoadFieldIDs(t.GetLoadFields(), t.GetSkipLoadDynamicField()) + if err != nil { + return err + } // check index indexResponse, err := t.datacoord.DescribeIndex(ctx, &indexpb.DescribeIndexRequest{ CollectionID: collID, @@ -1908,6 +1921,7 @@ func (t *loadPartitionsTask) Execute(ctx context.Context) error { FieldIndexID: fieldIndexIDs, Refresh: 
t.Refresh, ResourceGroups: t.ResourceGroups, + LoadFields: loadFields, } t.result, err = t.queryCoord.LoadPartitions(ctx, request) if err = merr.CheckRPCCall(t.result, err); err != nil { diff --git a/internal/proxy/task_index_test.go b/internal/proxy/task_index_test.go index 47251cc2ed..4f709f2a9c 100644 --- a/internal/proxy/task_index_test.go +++ b/internal/proxy/task_index_test.go @@ -188,6 +188,18 @@ func TestDropIndexTask_PreExecute(t *testing.T) { t.Run("coll has been loaded", func(t *testing.T) { qc := getMockQueryCoord() + qc.ExpectedCalls = nil + qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(merr.Success(), nil) + qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ + Status: merr.Success(), + Shards: []*querypb.ShardLeadersList{ + { + ChannelName: "channel-1", + NodeIds: []int64{1, 2, 3}, + NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"}, + }, + }, + }, nil) qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ Status: merr.Success(), CollectionIDs: []int64{collectionID}, @@ -200,6 +212,22 @@ func TestDropIndexTask_PreExecute(t *testing.T) { t.Run("show collection error", func(t *testing.T) { qc := getMockQueryCoord() + qc.ExpectedCalls = nil + qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(merr.Success(), nil) + qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ + Status: merr.Success(), + Shards: []*querypb.ShardLeadersList{ + { + ChannelName: "channel-1", + NodeIds: []int64{1, 2, 3}, + NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"}, + }, + }, + }, nil) + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ + Status: merr.Success(), + CollectionIDs: []int64{collectionID}, + }, nil) qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(nil, errors.New("error")) dit.queryCoord = qc @@ -209,6 +237,22 @@ func TestDropIndexTask_PreExecute(t *testing.T) { t.Run("show collection fail", func(t *testing.T) { qc := getMockQueryCoord() + qc.ExpectedCalls = nil + qc.EXPECT().LoadCollection(mock.Anything, mock.Anything).Return(merr.Success(), nil) + qc.EXPECT().GetShardLeaders(mock.Anything, mock.Anything).Return(&querypb.GetShardLeadersResponse{ + Status: merr.Success(), + Shards: []*querypb.ShardLeadersList{ + { + ChannelName: "channel-1", + NodeIds: []int64{1, 2, 3}, + NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"}, + }, + }, + }, nil) + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ + Status: merr.Success(), + CollectionIDs: []int64{collectionID}, + }, nil) qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ Status: &commonpb.Status{ ErrorCode: commonpb.ErrorCode_UnexpectedError, @@ -236,6 +280,7 @@ func getMockQueryCoord() *mocks.MockQueryCoordClient { }, }, }, nil) + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() return qc } diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index 06d318f6c7..592a90e0f3 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -73,6 +73,9 @@ func TestQueryTask_all(t *testing.T) { }, }, }, nil).Maybe() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ + Status: 
&successStatus, + }, nil).Maybe() mgr := NewMockShardClientManager(t) mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe() diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 66cb0dba90..b0f9d769b2 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -60,6 +60,9 @@ func TestSearchTask_PostExecute(t *testing.T) { defer rc.Close() require.NoError(t, err) mgr := newShardClientMgr() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{ + Status: merr.Success(), + }, nil).Maybe() err = InitMetaCache(ctx, rc, qc, mgr) require.NoError(t, err) @@ -191,6 +194,7 @@ func TestSearchTask_PreExecute(t *testing.T) { defer rc.Close() require.NoError(t, err) mgr := newShardClientMgr() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() err = InitMetaCache(ctx, rc, qc, mgr) require.NoError(t, err) @@ -335,6 +339,7 @@ func TestSearchTaskV2_Execute(t *testing.T) { defer rc.Close() mgr := newShardClientMgr() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() err = InitMetaCache(ctx, rc, qc, mgr) require.NoError(t, err) @@ -1786,6 +1791,7 @@ func TestSearchTask_ErrExecute(t *testing.T) { ) qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() mgr := NewMockShardClientManager(t) mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe() diff --git a/internal/proxy/task_statistic_test.go b/internal/proxy/task_statistic_test.go index c3438d0f13..42f0d63b44 100644 --- a/internal/proxy/task_statistic_test.go +++ b/internal/proxy/task_statistic_test.go @@ -68,6 +68,7 @@ func (s *StatisticTaskSuite) SetupTest() { }, }, }, nil).Maybe() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{ Status: merr.Success(), PartitionIDs: []int64{1, 2, 3}, diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index a85813c633..ed414ad831 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -977,6 +977,7 @@ func TestHasCollectionTask(t *testing.T) { defer rc.Close() qc := getQueryCoordClient() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() ctx := context.Background() mgr := newShardClientMgr() @@ -1123,6 +1124,7 @@ func TestDescribeCollectionTask_ShardsNum1(t *testing.T) { defer rc.Close() qc := getQueryCoordClient() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() ctx := context.Background() mgr := newShardClientMgr() @@ -1185,6 +1187,7 @@ func TestDescribeCollectionTask_EnableDynamicSchema(t *testing.T) { rc := NewRootCoordMock() defer rc.Close() qc := getQueryCoordClient() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() ctx := context.Background() mgr := newShardClientMgr() InitMetaCache(ctx, rc, qc, mgr) @@ -1248,6 +1251,7 @@ func TestDescribeCollectionTask_ShardsNum2(t *testing.T) { defer rc.Close() qc := getQueryCoordClient() + qc.EXPECT().ShowCollections(mock.Anything, 
mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() ctx := context.Background() mgr := newShardClientMgr() @@ -1611,6 +1615,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) { defer rc.Close() qc := getQueryCoordClient() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() ctx := context.Background() @@ -1803,6 +1808,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) { defer rc.Close() qc := getQueryCoordClient() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() ctx := context.Background() @@ -2676,6 +2682,7 @@ func TestCreateResourceGroupTask(t *testing.T) { defer rc.Close() qc := getQueryCoordClient() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() qc.EXPECT().CreateResourceGroup(mock.Anything, mock.Anything, mock.Anything).Return(merr.Success(), nil) ctx := context.Background() @@ -2716,6 +2723,7 @@ func TestDropResourceGroupTask(t *testing.T) { defer rc.Close() qc := getQueryCoordClient() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() qc.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).Return(merr.Success(), nil) ctx := context.Background() @@ -2756,6 +2764,7 @@ func TestTransferNodeTask(t *testing.T) { defer rc.Close() qc := getQueryCoordClient() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() qc.EXPECT().TransferNode(mock.Anything, mock.Anything).Return(merr.Success(), nil) ctx := context.Background() @@ -2796,6 +2805,7 @@ func TestTransferNodeTask(t *testing.T) { func TestTransferReplicaTask(t *testing.T) { rc := &MockRootCoordClientInterface{} qc := getQueryCoordClient() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() qc.EXPECT().TransferReplica(mock.Anything, mock.Anything).Return(merr.Success(), nil) ctx := context.Background() @@ -2839,6 +2849,7 @@ func TestTransferReplicaTask(t *testing.T) { func TestListResourceGroupsTask(t *testing.T) { rc := &MockRootCoordClientInterface{} qc := getQueryCoordClient() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() qc.EXPECT().ListResourceGroups(mock.Anything, mock.Anything).Return(&milvuspb.ListResourceGroupsResponse{ Status: merr.Success(), ResourceGroups: []string{meta.DefaultResourceGroupName, "rg"}, @@ -2882,6 +2893,7 @@ func TestListResourceGroupsTask(t *testing.T) { func TestDescribeResourceGroupTask(t *testing.T) { rc := &MockRootCoordClientInterface{} qc := getQueryCoordClient() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() qc.EXPECT().DescribeResourceGroup(mock.Anything, mock.Anything).Return(&querypb.DescribeResourceGroupResponse{ Status: merr.Success(), ResourceGroup: &querypb.ResourceGroupInfo{ @@ -2937,6 +2949,7 @@ func TestDescribeResourceGroupTask(t *testing.T) { func TestDescribeResourceGroupTaskFailed(t *testing.T) { rc := &MockRootCoordClientInterface{} qc := getQueryCoordClient() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() qc.EXPECT().DescribeResourceGroup(mock.Anything, mock.Anything).Return(&querypb.DescribeResourceGroupResponse{ Status: &commonpb.Status{ErrorCode: 
commonpb.ErrorCode_UnexpectedError}, }, nil) @@ -3142,8 +3155,11 @@ func TestCreateCollectionTaskWithPartitionKey(t *testing.T) { err = task.Execute(ctx) assert.NoError(t, err) + qc := getQueryCoordClient() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() + // check default partitions - err = InitMetaCache(ctx, rc, nil, nil) + err = InitMetaCache(ctx, rc, qc, nil) assert.NoError(t, err) partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, "", task.CollectionName) assert.NoError(t, err) @@ -3222,6 +3238,7 @@ func TestPartitionKey(t *testing.T) { defer rc.Close() qc := getQueryCoordClient() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() ctx := context.Background() @@ -3477,6 +3494,7 @@ func TestClusteringKey(t *testing.T) { defer rc.Close() qc := getQueryCoordClient() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() ctx := context.Background() @@ -3659,6 +3677,7 @@ func TestTaskPartitionKeyIsolation(t *testing.T) { dc := NewDataCoordMock() defer dc.Close() qc := getQueryCoordClient() + qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe() defer qc.Close() ctx := context.Background() mgr := newShardClientMgr() diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 2c1c60abc1..2555308d47 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -988,7 +988,8 @@ func translatePkOutputFields(schema *schemapb.CollectionSchema) ([]string, []int // output_fields=["*",C] ==> [A,B,C,D] func translateOutputFields(outputFields []string, schema *schemaInfo, addPrimary bool) ([]string, []string, error) { var primaryFieldName string - allFieldNameMap := make(map[string]bool) + var dynamicField *schemapb.FieldSchema + allFieldNameMap := make(map[string]int64) resultFieldNameMap := make(map[string]bool) resultFieldNames := make([]string, 0) userOutputFieldsMap := make(map[string]bool) @@ -998,37 +999,56 @@ if field.IsPrimaryKey { primaryFieldName = field.Name } - allFieldNameMap[field.Name] = true + if field.IsDynamic { + dynamicField = field + } + allFieldNameMap[field.Name] = field.GetFieldID() } for _, outputFieldName := range outputFields { outputFieldName = strings.TrimSpace(outputFieldName) if outputFieldName == "*" { - for fieldName := range allFieldNameMap { - resultFieldNameMap[fieldName] = true - userOutputFieldsMap[fieldName] = true + for fieldName, fieldID := range allFieldNameMap { + // skip cold field which is not loaded + if schema.IsFieldLoaded(fieldID) { + resultFieldNameMap[fieldName] = true + userOutputFieldsMap[fieldName] = true + } } } else { - if _, ok := allFieldNameMap[outputFieldName]; ok { - resultFieldNameMap[outputFieldName] = true - userOutputFieldsMap[outputFieldName] = true - } else { - if schema.EnableDynamicField { - err := planparserv2.ParseIdentifier(schema.schemaHelper, outputFieldName, func(expr *planpb.Expr) error { - if len(expr.GetColumnExpr().GetInfo().GetNestedPath()) == 1 && - expr.GetColumnExpr().GetInfo().GetNestedPath()[0] == outputFieldName { - return nil - } - return fmt.Errorf("not support getting subkeys of json field yet") - }) - if err != nil { - log.Info("parse output field name failed", zap.String("field name", outputFieldName)) - return nil, nil, fmt.Errorf("parse output field name failed: %s",
outputFieldName) - } - resultFieldNameMap[common.MetaFieldName] = true + if fieldID, ok := allFieldNameMap[outputFieldName]; ok { + if schema.IsFieldLoaded(fieldID) { + resultFieldNameMap[outputFieldName] = true userOutputFieldsMap[outputFieldName] = true + } else { + return nil, nil, fmt.Errorf("field %s is not loaded", outputFieldName) + } + } else { + if schema.EnableDynamicField { + if schema.IsFieldLoaded(dynamicField.GetFieldID()) { + schemaH, err := typeutil.CreateSchemaHelper(schema.CollectionSchema) + if err != nil { + return nil, nil, err + } + err = planparserv2.ParseIdentifier(schemaH, outputFieldName, func(expr *planpb.Expr) error { + if len(expr.GetColumnExpr().GetInfo().GetNestedPath()) == 1 && + expr.GetColumnExpr().GetInfo().GetNestedPath()[0] == outputFieldName { + return nil + } + return fmt.Errorf("not support getting subkeys of json field yet") + }) + if err != nil { + log.Info("parse output field name failed", zap.String("field name", outputFieldName)) + return nil, nil, fmt.Errorf("parse output field name failed: %s", outputFieldName) + } + resultFieldNameMap[common.MetaFieldName] = true + userOutputFieldsMap[outputFieldName] = true + } else { + // TODO: remove this check once cold fields can be fetched with the chunk cache + return nil, nil, fmt.Errorf("field %s cannot be returned since the dynamic field is not loaded", outputFieldName) + } + } else { + return nil, nil, fmt.Errorf("field %s does not exist", outputFieldName) + } } } } diff --git a/internal/querycoordv2/job/job_load.go b/internal/querycoordv2/job/job_load.go index 234ab219e9..4ade22ee48 100644 --- a/internal/querycoordv2/job/job_load.go +++ b/internal/querycoordv2/job/job_load.go @@ -191,6 +191,7 @@ func (job *LoadCollectionJob) Execute() error { Status: querypb.LoadStatus_Loading, FieldIndexID: req.GetFieldIndexID(), LoadType: querypb.LoadType_LoadCollection, + LoadFields: req.GetLoadFields(), }, CreatedAt: time.Now(), LoadSpan: sp, @@ -371,6 +372,7 @@ func (job *LoadPartitionJob) Execute() error { Status: querypb.LoadStatus_Loading, FieldIndexID: req.GetFieldIndexID(), LoadType: querypb.LoadType_LoadPartition, + LoadFields: req.GetLoadFields(), }, CreatedAt: time.Now(), LoadSpan: sp, diff --git a/internal/querycoordv2/job/job_test.go b/internal/querycoordv2/job/job_test.go index d509d0ac51..276234990d 100644 --- a/internal/querycoordv2/job/job_test.go +++ b/internal/querycoordv2/job/job_test.go @@ -38,6 +38,7 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/observers" . "github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/session" + "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/pkg/kv" "github.com/milvus-io/milvus/pkg/util/etcd" "github.com/milvus-io/milvus/pkg/util/merr" @@ -71,6 +72,7 @@ type JobSuite struct { broker *meta.MockBroker nodeMgr *session.NodeManager checkerController *checkers.CheckerController + proxyManager *proxyutil.MockProxyClientManager // Test objects scheduler *Scheduler @@ -140,6 +142,9 @@ func (suite *JobSuite) SetupSuite() { suite.cluster.EXPECT(). ReleasePartitions(mock.Anything, mock.Anything, mock.Anything).
Return(merr.Success(), nil).Maybe() + + suite.proxyManager = proxyutil.NewMockProxyClientManager(suite.T()) + suite.proxyManager.EXPECT().InvalidateCollectionMetaCache(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() } func (suite *JobSuite) SetupTest() { @@ -199,6 +204,7 @@ func (suite *JobSuite) SetupTest() { suite.targetMgr, suite.targetObserver, suite.checkerController, + suite.proxyManager, ) } diff --git a/internal/querycoordv2/meta/collection_manager.go b/internal/querycoordv2/meta/collection_manager.go index f7fa1a5685..7071556460 100644 --- a/internal/querycoordv2/meta/collection_manager.go +++ b/internal/querycoordv2/meta/collection_manager.go @@ -356,6 +356,17 @@ func (m *CollectionManager) GetFieldIndex(collectionID typeutil.UniqueID) map[in return nil } +func (m *CollectionManager) GetLoadFields(collectionID typeutil.UniqueID) []int64 { + m.rwmutex.RLock() + defer m.rwmutex.RUnlock() + + collection, ok := m.collections[collectionID] + if ok { + return collection.GetLoadFields() + } + return nil +} + func (m *CollectionManager) Exist(collectionID typeutil.UniqueID) bool { m.rwmutex.RLock() defer m.rwmutex.RUnlock() diff --git a/internal/querycoordv2/observers/collection_observer.go b/internal/querycoordv2/observers/collection_observer.go index 6a935a6d52..69235f783c 100644 --- a/internal/querycoordv2/observers/collection_observer.go +++ b/internal/querycoordv2/observers/collection_observer.go @@ -26,14 +26,18 @@ import ( "go.opentelemetry.io/otel/trace" "go.uber.org/zap" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/proto/proxypb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querycoordv2/checkers" "github.com/milvus-io/milvus/internal/querycoordv2/meta" . 
"github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/utils" + "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/eventlog" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -50,6 +54,8 @@ type CollectionObserver struct { loadTasks *typeutil.ConcurrentMap[string, LoadTask] + proxyManager proxyutil.ProxyClientManagerInterface + stopOnce sync.Once } @@ -65,6 +71,7 @@ func NewCollectionObserver( targetMgr meta.TargetManagerInterface, targetObserver *TargetObserver, checherController *checkers.CheckerController, + proxyManager proxyutil.ProxyClientManagerInterface, ) *CollectionObserver { ob := &CollectionObserver{ dist: dist, @@ -74,6 +81,7 @@ func NewCollectionObserver( checkerController: checherController, partitionLoadedCount: make(map[int64]int), loadTasks: typeutil.NewConcurrentMap[string, LoadTask](), + proxyManager: proxyManager, } // Add load task for collection recovery @@ -347,5 +355,20 @@ func (ob *CollectionObserver) observePartitionLoadStatus(ctx context.Context, pa zap.Int32("partitionLoadPercentage", loadPercentage), zap.Int32("collectionLoadPercentage", collectionPercentage), ) + if collectionPercentage == 100 { + ob.invalidateCache(ctx, partition.GetCollectionID()) + } eventlog.Record(eventlog.NewRawEvt(eventlog.Level_Info, fmt.Sprintf("collection %d load percentage update: %d", partition.CollectionID, loadPercentage))) } + +func (ob *CollectionObserver) invalidateCache(ctx context.Context, collectionID int64) { + ctx, cancel := context.WithTimeout(ctx, paramtable.Get().QueryCoordCfg.BrokerTimeout.GetAsDuration(time.Second)) + defer cancel() + err := ob.proxyManager.InvalidateCollectionMetaCache(ctx, &proxypb.InvalidateCollMetaCacheRequest{ + CollectionID: collectionID, + }, proxyutil.SetMsgType(commonpb.MsgType_LoadCollection)) + if err != nil { + log.Warn("failed to invalidate proxy's shard leader cache", zap.Error(err)) + return + } +} diff --git a/internal/querycoordv2/observers/collection_observer_test.go b/internal/querycoordv2/observers/collection_observer_test.go index 6e8d4f541d..6f26a92452 100644 --- a/internal/querycoordv2/observers/collection_observer_test.go +++ b/internal/querycoordv2/observers/collection_observer_test.go @@ -35,6 +35,7 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/meta" . 
"github.com/milvus-io/milvus/internal/querycoordv2/params" "github.com/milvus-io/milvus/internal/querycoordv2/session" + "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/pkg/kv" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/etcd" @@ -55,12 +56,13 @@ type CollectionObserverSuite struct { nodes []int64 // Mocks - idAllocator func() (int64, error) - etcd *clientv3.Client - kv kv.MetaKv - store metastore.QueryCoordCatalog - broker *meta.MockBroker - cluster *session.MockCluster + idAllocator func() (int64, error) + etcd *clientv3.Client + kv kv.MetaKv + store metastore.QueryCoordCatalog + broker *meta.MockBroker + cluster *session.MockCluster + proxyManager *proxyutil.MockProxyClientManager // Dependencies dist *meta.DistributionManager @@ -162,6 +164,9 @@ func (suite *CollectionObserverSuite) SetupSuite() { 103: 2, } suite.nodes = []int64{1, 2, 3} + + suite.proxyManager = proxyutil.NewMockProxyClientManager(suite.T()) + suite.proxyManager.EXPECT().InvalidateCollectionMetaCache(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() } func (suite *CollectionObserverSuite) SetupTest() { @@ -209,6 +214,7 @@ func (suite *CollectionObserverSuite) SetupTest() { suite.targetMgr, suite.targetObserver, suite.checkerController, + suite.proxyManager, ) for _, collection := range suite.collections { diff --git a/internal/querycoordv2/ops_service_test.go b/internal/querycoordv2/ops_service_test.go index 82ef1f696d..9c265620c1 100644 --- a/internal/querycoordv2/ops_service_test.go +++ b/internal/querycoordv2/ops_service_test.go @@ -41,6 +41,7 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/internal/querycoordv2/utils" + "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/kv" "github.com/milvus-io/milvus/pkg/util/etcd" @@ -66,6 +67,7 @@ type OpsServiceSuite struct { jobScheduler *job.Scheduler taskScheduler *task.MockScheduler balancer balance.Balance + proxyManager *proxyutil.MockProxyClientManager distMgr *meta.DistributionManager distController *dist.MockController @@ -77,6 +79,8 @@ type OpsServiceSuite struct { func (suite *OpsServiceSuite) SetupSuite() { paramtable.Init() + suite.proxyManager = proxyutil.NewMockProxyClientManager(suite.T()) + suite.proxyManager.EXPECT().InvalidateCollectionMetaCache(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() } func (suite *OpsServiceSuite) SetupTest() { @@ -151,6 +155,7 @@ func (suite *OpsServiceSuite) SetupTest() { suite.server.targetMgr, suite.targetObserver, &checkers.CheckerController{}, + suite.proxyManager, ) suite.server.UpdateStateCode(commonpb.StateCode_Healthy) diff --git a/internal/querycoordv2/server.go b/internal/querycoordv2/server.go index d2c997e339..35398beb9c 100644 --- a/internal/querycoordv2/server.go +++ b/internal/querycoordv2/server.go @@ -410,6 +410,7 @@ func (s *Server) initObserver() { s.targetMgr, s.targetObserver, s.checkerController, + s.proxyClientManager, ) s.replicaObserver = observers.NewReplicaObserver( diff --git a/internal/querycoordv2/server_test.go b/internal/querycoordv2/server_test.go index 78c2fdb89b..ca3a84b839 100644 --- a/internal/querycoordv2/server_test.go +++ b/internal/querycoordv2/server_test.go @@ -587,6 +587,7 @@ func (suite *ServerSuite) hackServer() { suite.server.targetMgr, suite.server.targetObserver, 
suite.server.checkerController, + suite.server.proxyClientManager, ) suite.broker.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{Schema: &schemapb.CollectionSchema{}}, nil).Maybe() diff --git a/internal/querycoordv2/services.go b/internal/querycoordv2/services.go index ea132a3a91..95ca6cead9 100644 --- a/internal/querycoordv2/services.go +++ b/internal/querycoordv2/services.go @@ -28,6 +28,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querycoordv2/job" @@ -86,6 +87,7 @@ func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectio collection := s.meta.CollectionManager.GetCollection(collectionID) percentage := s.meta.CollectionManager.CalculateLoadPercentage(collectionID) + loadFields := s.meta.CollectionManager.GetLoadFields(collectionID) refreshProgress := int64(0) if percentage < 0 { if isGetAll { @@ -118,6 +120,9 @@ func (s *Server) ShowCollections(ctx context.Context, req *querypb.ShowCollectio resp.InMemoryPercentages = append(resp.InMemoryPercentages, int64(percentage)) resp.QueryServiceAvailable = append(resp.QueryServiceAvailable, s.checkAnyReplicaAvailable(collectionID)) resp.RefreshProgress = append(resp.RefreshProgress, refreshProgress) + resp.LoadFields = append(resp.LoadFields, &schemapb.LongArray{ + Data: loadFields, + }) } return resp, nil diff --git a/internal/querycoordv2/services_test.go b/internal/querycoordv2/services_test.go index bdb684743c..5f0c600167 100644 --- a/internal/querycoordv2/services_test.go +++ b/internal/querycoordv2/services_test.go @@ -47,6 +47,7 @@ import ( "github.com/milvus-io/milvus/internal/querycoordv2/session" "github.com/milvus-io/milvus/internal/querycoordv2/task" "github.com/milvus-io/milvus/internal/querycoordv2/utils" + "github.com/milvus-io/milvus/internal/util/proxyutil" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/pkg/kv" "github.com/milvus-io/milvus/pkg/util/etcd" @@ -86,6 +87,8 @@ type ServiceSuite struct { distMgr *meta.DistributionManager distController *dist.MockController + proxyManager *proxyutil.MockProxyClientManager + // Test object server *Server } @@ -124,6 +127,9 @@ func (suite *ServiceSuite) SetupSuite() { 1, 2, 3, 4, 5, 101, 102, 103, 104, 105, } + + suite.proxyManager = proxyutil.NewMockProxyClientManager(suite.T()) + suite.proxyManager.EXPECT().InvalidateCollectionMetaCache(mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() } func (suite *ServiceSuite) SetupTest() { @@ -185,6 +191,7 @@ func (suite *ServiceSuite) SetupTest() { suite.targetMgr, suite.targetObserver, &checkers.CheckerController{}, + suite.proxyManager, ) suite.collectionObserver.Start() diff --git a/internal/querycoordv2/task/executor.go b/internal/querycoordv2/task/executor.go index 1a2e3e6edd..3c7941e497 100644 --- a/internal/querycoordv2/task/executor.go +++ b/internal/querycoordv2/task/executor.go @@ -343,6 +343,7 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error { log.Warn("failed to get collection info") return err } + loadFields := ex.meta.GetLoadFields(task.CollectionID()) partitions, err := utils.GetPartitions(ex.meta.CollectionManager, task.CollectionID()) if err != nil { log.Warn("failed to get partitions of 
collection") @@ -358,6 +359,7 @@ func (ex *Executor) subscribeChannel(task *ChannelTask, step int) error { task.CollectionID(), collectionInfo.GetDbName(), task.ResourceGroup(), + loadFields, partitions..., ) @@ -649,6 +651,7 @@ func (ex *Executor) getMetaInfo(ctx context.Context, task Task) (*milvuspb.Descr log.Warn("failed to get collection info", zap.Error(err)) return nil, nil, nil, err } + loadFields := ex.meta.GetLoadFields(task.CollectionID()) partitions, err := utils.GetPartitions(ex.meta.CollectionManager, collectionID) if err != nil { log.Warn("failed to get partitions of collection", zap.Error(err)) @@ -660,6 +663,7 @@ func (ex *Executor) getMetaInfo(ctx context.Context, task Task) (*milvuspb.Descr task.CollectionID(), collectionInfo.GetDbName(), task.ResourceGroup(), + loadFields, partitions..., ) diff --git a/internal/querycoordv2/task/utils.go b/internal/querycoordv2/task/utils.go index 6bf9a289ce..7536e1ab10 100644 --- a/internal/querycoordv2/task/utils.go +++ b/internal/querycoordv2/task/utils.go @@ -182,13 +182,14 @@ func packReleaseSegmentRequest(task *SegmentTask, action *SegmentAction) *queryp } } -func packLoadMeta(loadType querypb.LoadType, collectionID int64, databaseName string, resourceGroup string, partitions ...int64) *querypb.LoadMetaInfo { +func packLoadMeta(loadType querypb.LoadType, collectionID int64, databaseName string, resourceGroup string, loadFields []int64, partitions ...int64) *querypb.LoadMetaInfo { return &querypb.LoadMetaInfo{ LoadType: loadType, CollectionID: collectionID, PartitionIDs: partitions, DbName: databaseName, ResourceGroup: resourceGroup, + LoadFields: loadFields, } } diff --git a/internal/querynodev2/segments/collection.go b/internal/querynodev2/segments/collection.go index 6baf1bd4bb..86c85e0f80 100644 --- a/internal/querynodev2/segments/collection.go +++ b/internal/querynodev2/segments/collection.go @@ -146,6 +146,7 @@ type Collection struct { metricType atomic.String // deprecated schema atomic.Pointer[schemapb.CollectionSchema] isGpuIndex bool + loadFields typeutil.Set[int64] refCount *atomic.Uint32 } @@ -227,7 +228,23 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM CCollection NewCollection(const char* schema_proto_blob); */ - schemaBlob, err := proto.Marshal(schema) + + var loadFieldIDs typeutil.Set[int64] + loadSchema := typeutil.Clone(schema) + + // if load fields is specified, do filtering logic + // otherwise use all fields for backward compatibility + if len(loadMetaInfo.GetLoadFields()) > 0 { + loadFieldIDs = typeutil.NewSet(loadMetaInfo.GetLoadFields()...) + loadSchema.Fields = lo.Filter(loadSchema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool { + // system field shall always be loaded for now + return loadFieldIDs.Contain(field.GetFieldID()) || common.IsSystemField(field.GetFieldID()) + }) + } else { + loadFieldIDs = typeutil.NewSet(lo.Map(loadSchema.GetFields(), func(field *schemapb.FieldSchema, _ int) int64 { return field.GetFieldID() })...) 
+ } + + schemaBlob, err := proto.Marshal(loadSchema) if err != nil { log.Warn("marshal schema failed", zap.Error(err)) return nil @@ -263,6 +280,7 @@ resourceGroup: loadMetaInfo.GetResourceGroup(), refCount: atomic.NewUint32(0), isGpuIndex: isGpuIndex, + loadFields: loadFieldIDs, } for _, partitionID := range loadMetaInfo.GetPartitionIDs() { coll.partitions.Insert(partitionID) diff --git a/internal/querynodev2/segments/segment_loader.go b/internal/querynodev2/segments/segment_loader.go index ec4127b7d6..f0b3989cb8 100644 --- a/internal/querynodev2/segments/segment_loader.go +++ b/internal/querynodev2/segments/segment_loader.go @@ -206,6 +206,14 @@ func (loader *segmentLoader) Load(ctx context.Context, log.Info("no segment to load") return nil, nil } + coll := loader.manager.Collection.Get(collectionID) + // keep only binlogs of fields which need to be loaded + for _, info := range segments { + info.BinlogPaths = lo.Filter(info.GetBinlogPaths(), func(fbl *datapb.FieldBinlog, _ int) bool { + return coll.loadFields.Contain(fbl.GetFieldID()) || common.IsSystemField(fbl.GetFieldID()) + }) + } + // Filter out loaded & loading segments infos := loader.prepare(ctx, segmentType, segments...) defer loader.unregister(infos...) @@ -220,7 +228,7 @@ var err error var requestResourceResult requestResourceResult - coll := loader.manager.Collection.Get(collectionID) + if !isLazyLoad(coll, segmentType) { // Check memory & storage limit // no need to check resource for lazy load here diff --git a/pkg/common/common.go b/pkg/common/common.go index deb29f97a7..6dc4934506 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -24,9 +24,11 @@ import ( "github.com/cockroachdb/errors" "github.com/samber/lo" + "go.uber.org/zap" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/log" ) // system field id: @@ -168,6 +170,7 @@ const ( MmapEnabledKey = "mmap.enabled" LazyLoadEnableKey = "lazyload.enabled" PartitionKeyIsolationKey = "partitionkey.isolation" + FieldSkipLoadKey = "field.skipLoad" ) const ( @@ -328,3 +331,35 @@ func CollectionLevelResourceGroups(kvs []*commonpb.KeyValuePair) ([]string, erro return nil, fmt.Errorf("collection property not found: %s", CollectionReplicaNumber) } + +// GetCollectionLoadFields returns the load field IDs according to the type params.
+func GetCollectionLoadFields(schema *schemapb.CollectionSchema, skipDynamicField bool) []int64 { + return lo.FilterMap(schema.GetFields(), func(field *schemapb.FieldSchema, _ int) (int64, bool) { + // skip system field + if IsSystemField(field.GetFieldID()) { + return field.GetFieldID(), false + } + // skip dynamic field if specified + if field.IsDynamic && skipDynamicField { + return field.GetFieldID(), false + } + + v, err := ShouldFieldBeLoaded(field.GetTypeParams()) + if err != nil { + log.Warn("type param parse skip load failed", zap.Error(err)) + // if configuration cannot be parsed, ignore it and load field + return field.GetFieldID(), true + } + return field.GetFieldID(), v + }) +} + +func ShouldFieldBeLoaded(kvs []*commonpb.KeyValuePair) (bool, error) { + for _, kv := range kvs { + if kv.GetKey() == FieldSkipLoadKey { + val, err := strconv.ParseBool(kv.GetValue()) + return !val, err + } + } + return true, nil +} diff --git a/pkg/common/common_test.go b/pkg/common/common_test.go index 11ca8949f1..7e77b782f3 100644 --- a/pkg/common/common_test.go +++ b/pkg/common/common_test.go @@ -149,3 +149,31 @@ func TestCommonPartitionKeyIsolation(t *testing.T) { assert.False(t, res) }) } + +func TestShouldFieldBeLoaded(t *testing.T) { + type testCase struct { + tag string + input []*commonpb.KeyValuePair + expectOutput bool + expectError bool + } + + testcases := []testCase{ + {tag: "no_params", expectOutput: true}, + {tag: "skipload_true", input: []*commonpb.KeyValuePair{{Key: FieldSkipLoadKey, Value: "true"}}, expectOutput: false}, + {tag: "skipload_false", input: []*commonpb.KeyValuePair{{Key: FieldSkipLoadKey, Value: "false"}}, expectOutput: true}, + {tag: "bad_skip_load_value", input: []*commonpb.KeyValuePair{{Key: FieldSkipLoadKey, Value: "abc"}}, expectError: true}, + } + + for _, tc := range testcases { + t.Run(tc.tag, func(t *testing.T) { + result, err := ShouldFieldBeLoaded(tc.input) + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expectOutput, result) + } + }) + } +} diff --git a/pkg/go.mod b/pkg/go.mod index 9aa438ac2c..65bd79ee10 100644 --- a/pkg/go.mod +++ b/pkg/go.mod @@ -11,7 +11,6 @@ require ( github.com/confluentinc/confluent-kafka-go v1.9.1 github.com/containerd/cgroups/v3 v3.0.3 github.com/expr-lang/expr v1.15.7 - github.com/golang/protobuf v1.5.4 github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/klauspost/compress v1.17.7 github.com/milvus-io/milvus-proto/go-api/v2 v2.3.4-0.20240815123953-6dab6fcd6454 @@ -93,6 +92,7 @@ require ( github.com/godbus/dbus/v5 v5.0.4 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect + github.com/golang/protobuf v1.5.4 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/btree v1.1.2 // indirect github.com/google/uuid v1.6.0 // indirect diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index b808ab5fa0..36f81f11e2 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -257,14 +257,14 @@ type SchemaHelper struct { primaryKeyOffset int partitionKeyOffset int dynamicFieldOffset int + loadFields Set[int64] } -// CreateSchemaHelper returns a new SchemaHelper object -func CreateSchemaHelper(schema *schemapb.CollectionSchema) (*SchemaHelper, error) { +func CreateSchemaHelperWithLoadFields(schema *schemapb.CollectionSchema, loadFields []int64) (*SchemaHelper, error) { if schema == nil { return nil, errors.New("schema is nil") } - schemaHelper := 
SchemaHelper{schema: schema, nameOffset: make(map[string]int), idOffset: make(map[int64]int), primaryKeyOffset: -1, partitionKeyOffset: -1, dynamicFieldOffset: -1} + schemaHelper := SchemaHelper{schema: schema, nameOffset: make(map[string]int), idOffset: make(map[int64]int), primaryKeyOffset: -1, partitionKeyOffset: -1, dynamicFieldOffset: -1, loadFields: NewSet(loadFields...)} for offset, field := range schema.Fields { if _, ok := schemaHelper.nameOffset[field.Name]; ok { return nil, fmt.Errorf("duplicated fieldName: %s", field.Name) @@ -298,6 +298,11 @@ return &schemaHelper, nil } +// CreateSchemaHelper returns a new SchemaHelper object +func CreateSchemaHelper(schema *schemapb.CollectionSchema) (*SchemaHelper, error) { + return CreateSchemaHelperWithLoadFields(schema, nil) +} + // GetPrimaryKeyField returns the schema of the primary key func (helper *SchemaHelper) GetPrimaryKeyField() (*schemapb.FieldSchema, error) { if helper.primaryKeyOffset == -1 { @@ -338,12 +343,28 @@ func (helper *SchemaHelper) GetFieldFromNameDefaultJSON(fieldName string) (*sche if !ok { return helper.getDefaultJSONField(fieldName) } - return helper.schema.Fields[offset], nil + fieldSchema := helper.schema.Fields[offset] + if !helper.IsFieldLoaded(fieldSchema.GetFieldID()) { + return nil, errors.Newf("field %s is not loaded", fieldSchema.GetName()) + } + return fieldSchema, nil +} + +// IsFieldLoaded returns whether the field with the given fieldID is loaded. +// If the load field list is not provided, all fields are treated as loaded. +func (helper *SchemaHelper) IsFieldLoaded(fieldID int64) bool { + if len(helper.loadFields) == 0 { + return true + } + return helper.loadFields.Contain(fieldID) } func (helper *SchemaHelper) getDefaultJSONField(fieldName string) (*schemapb.FieldSchema, error) { for _, f := range helper.schema.GetFields() { if f.DataType == schemapb.DataType_JSON && f.IsDynamic { + if !helper.IsFieldLoaded(f.GetFieldID()) { + return nil, errors.Newf("field %s is dynamic but dynamic field is not loaded", fieldName) + } return f, nil } }
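
Reviewer note: below is a minimal usage sketch of the two helpers introduced in pkg/common/common.go above (GetCollectionLoadFields and ShouldFieldBeLoaded). The schema, field names, and IDs are hypothetical and exist only to illustrate how the new field.skipLoad type param drives the default load field list; this is not part of the diff.

package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/pkg/common"
)

func main() {
	// Hypothetical schema: a primary key, a scalar field opted out of loading
	// via the new field.skipLoad type param, and a vector field.
	schema := &schemapb.CollectionSchema{
		Fields: []*schemapb.FieldSchema{
			{FieldID: common.StartOfUserFieldID, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
			{FieldID: common.StartOfUserFieldID + 1, Name: "text", DataType: schemapb.DataType_VarChar, TypeParams: []*commonpb.KeyValuePair{
				{Key: common.FieldSkipLoadKey, Value: "true"},
			}},
			{FieldID: common.StartOfUserFieldID + 2, Name: "vector", DataType: schemapb.DataType_FloatVector},
		},
	}

	// ShouldFieldBeLoaded inspects a single field's type params; it returns
	// false here because field.skipLoad is set to "true".
	loaded, err := common.ShouldFieldBeLoaded(schema.Fields[1].GetTypeParams())
	fmt.Println(loaded, err) // false <nil>

	// GetCollectionLoadFields derives the default load field list from the
	// schema, dropping the skipLoad field: prints [100 102].
	fmt.Println(common.GetCollectionLoadFields(schema, false))
}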