diff --git a/internal/datacoord/compaction_trigger.go b/internal/datacoord/compaction_trigger.go index 88cfdfd2b5..6813add8f5 100644 --- a/internal/datacoord/compaction_trigger.go +++ b/internal/datacoord/compaction_trigger.go @@ -28,11 +28,8 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/lifetime" "github.com/milvus-io/milvus/pkg/util/lock" "github.com/milvus-io/milvus/pkg/util/logutil" @@ -302,8 +299,6 @@ func (t *compactionTrigger) allocSignalID() (UniqueID, error) { } func (t *compactionTrigger) getExpectedSegmentSize(collectionID int64) int64 { - indexInfos := t.meta.indexMeta.GetIndexesForCollection(collectionID, "") - ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() collMeta, err := t.handler.GetCollection(ctx, collectionID) @@ -311,19 +306,7 @@ func (t *compactionTrigger) getExpectedSegmentSize(collectionID int64) int64 { log.Warn("failed to get collection", zap.Int64("collectionID", collectionID), zap.Error(err)) return Params.DataCoordCfg.SegmentMaxSize.GetAsInt64() * 1024 * 1024 } - - vectorFields := typeutil.GetVectorFieldSchemas(collMeta.Schema) - fieldIndexTypes := lo.SliceToMap(indexInfos, func(t *model.Index) (int64, indexparamcheck.IndexType) { - return t.FieldID, GetIndexType(t.IndexParams) - }) - vectorFieldsWithDiskIndex := lo.Filter(vectorFields, func(field *schemapb.FieldSchema, _ int) bool { - if indexType, ok := fieldIndexTypes[field.FieldID]; ok { - return indexparamcheck.IsDiskIndex(indexType) - } - return false - }) - - allDiskIndex := len(vectorFields) == len(vectorFieldsWithDiskIndex) + allDiskIndex := t.meta.indexMeta.AreAllDiskIndex(collectionID, collMeta.Schema) if allDiskIndex { // Only if all vector fields index type are DiskANN, recalc segment max size here. return Params.DataCoordCfg.DiskSegmentMaxSize.GetAsInt64() * 1024 * 1024 diff --git a/internal/datacoord/import_checker.go b/internal/datacoord/import_checker.go index 7a73fa9299..1213d0a0c2 100644 --- a/internal/datacoord/import_checker.go +++ b/internal/datacoord/import_checker.go @@ -218,7 +218,8 @@ func (c *importChecker) checkPreImportingJob(job ImportJob) { return } - groups := RegroupImportFiles(job, lacks) + allDiskIndex := c.meta.indexMeta.AreAllDiskIndex(job.GetCollectionID(), job.GetSchema()) + groups := RegroupImportFiles(job, lacks, allDiskIndex) newTasks, err := NewImportTasks(groups, job, c.sm, c.alloc) if err != nil { log.Warn("new import tasks failed", zap.Error(err)) diff --git a/internal/datacoord/import_util.go b/internal/datacoord/import_util.go index 0439b8da6f..1775c16e50 100644 --- a/internal/datacoord/import_util.go +++ b/internal/datacoord/import_util.go @@ -226,13 +226,20 @@ func AssembleImportRequest(task ImportTask, job ImportJob, meta *meta, alloc all }, nil } -func RegroupImportFiles(job ImportJob, files []*datapb.ImportFileStats) [][]*datapb.ImportFileStats { +func RegroupImportFiles(job ImportJob, files []*datapb.ImportFileStats, allDiskIndex bool) [][]*datapb.ImportFileStats { if len(files) == 0 { return nil } + var segmentMaxSize int + if allDiskIndex { + // Only if all vector fields index type are DiskANN, recalc segment max size here. + segmentMaxSize = Params.DataCoordCfg.DiskSegmentMaxSize.GetAsInt() * 1024 * 1024 + } else { + // If some vector fields index type are not DiskANN, recalc segment max size using default policy. + segmentMaxSize = Params.DataCoordCfg.SegmentMaxSize.GetAsInt() * 1024 * 1024 + } isL0Import := importutilv2.IsL0Import(job.GetOptions()) - segmentMaxSize := paramtable.Get().DataCoordCfg.SegmentMaxSize.GetAsInt() * 1024 * 1024 if isL0Import { segmentMaxSize = paramtable.Get().DataNodeCfg.FlushDeleteBufferBytes.GetAsInt() } diff --git a/internal/datacoord/import_util_test.go b/internal/datacoord/import_util_test.go index 50ce606bef..411bfb51ff 100644 --- a/internal/datacoord/import_util_test.go +++ b/internal/datacoord/import_util_test.go @@ -208,7 +208,8 @@ func TestImportUtil_RegroupImportFiles(t *testing.T) { Vchannels: []string{"v0", "v1", "v2", "v3"}, }, } - groups := RegroupImportFiles(job, files) + + groups := RegroupImportFiles(job, files, false) total := 0 for i, fs := range groups { sum := lo.SumBy(fs, func(f *datapb.ImportFileStats) int64 { diff --git a/internal/datacoord/index_meta.go b/internal/datacoord/index_meta.go index 58a9175fc2..41685db08e 100644 --- a/internal/datacoord/index_meta.go +++ b/internal/datacoord/index_meta.go @@ -29,12 +29,14 @@ import ( "google.golang.org/protobuf/proto" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/metastore" "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/indexparamcheck" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -926,3 +928,21 @@ func (m *indexMeta) GetUnindexedSegments(collectionID int64, segmentIDs []int64) } return lo.Without(segmentIDs, indexed...) } + +func (m *indexMeta) AreAllDiskIndex(collectionID int64, schema *schemapb.CollectionSchema) bool { + indexInfos := m.GetIndexesForCollection(collectionID, "") + + vectorFields := typeutil.GetVectorFieldSchemas(schema) + fieldIndexTypes := lo.SliceToMap(indexInfos, func(t *model.Index) (int64, indexparamcheck.IndexType) { + return t.FieldID, GetIndexType(t.IndexParams) + }) + vectorFieldsWithDiskIndex := lo.Filter(vectorFields, func(field *schemapb.FieldSchema, _ int) bool { + if indexType, ok := fieldIndexTypes[field.FieldID]; ok { + return indexparamcheck.IsDiskIndex(indexType) + } + return false + }) + + allDiskIndex := len(vectorFields) == len(vectorFieldsWithDiskIndex) + return allDiskIndex +}