diff --git a/internal/datacoord/compaction_policy_clustering.go b/internal/datacoord/compaction_policy_clustering.go index 1bf6989b8c..5a06af6cc6 100644 --- a/internal/datacoord/compaction_policy_clustering.go +++ b/internal/datacoord/compaction_policy_clustering.go @@ -28,6 +28,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/datacoord/allocator" "github.com/milvus-io/milvus/internal/util/clustering" + "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" @@ -147,7 +148,7 @@ func (policy *clusteringCompactionPolicy) triggerOneCollection(ctx context.Conte continue } - collectionTTL, err := getCollectionTTL(collection.Properties) + collectionTTL, err := common.GetCollectionTTLFromMap(collection.Properties, paramtable.Get().CommonCfg.EntityExpirationTTL.GetAsDuration(time.Second)) if err != nil { log.Warn("get collection ttl failed, skip to handle compaction") return make([]CompactionView, 0), 0, err diff --git a/internal/datacoord/compaction_policy_forcemerge.go b/internal/datacoord/compaction_policy_forcemerge.go index 0d11537be1..90d0ec4649 100644 --- a/internal/datacoord/compaction_policy_forcemerge.go +++ b/internal/datacoord/compaction_policy_forcemerge.go @@ -3,6 +3,7 @@ package datacoord import ( "context" "fmt" + "time" "github.com/samber/lo" "go.uber.org/zap" @@ -12,6 +13,7 @@ import ( "github.com/milvus-io/milvus/internal/datacoord/session" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/sessionutil" + "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" "github.com/milvus-io/milvus/pkg/v2/util/merr" @@ -78,7 +80,7 @@ func (policy *forceMergeCompactionPolicy) triggerOneCollection( return nil, 0, err } - collectionTTL, err := getCollectionTTL(collection.Properties) + collectionTTL, err := common.GetCollectionTTLFromMap(collection.Properties, paramtable.Get().CommonCfg.EntityExpirationTTL.GetAsDuration(time.Second)) if err != nil { log.Warn("failed to get collection ttl, use default", zap.Error(err)) collectionTTL = 0 diff --git a/internal/datacoord/compaction_policy_single.go b/internal/datacoord/compaction_policy_single.go index e486b5f483..3e45aa4c78 100644 --- a/internal/datacoord/compaction_policy_single.go +++ b/internal/datacoord/compaction_policy_single.go @@ -29,6 +29,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" "github.com/milvus-io/milvus/pkg/v2/util/merr" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" ) // singleCompactionPolicy is to compact one segment with too many delta logs @@ -111,7 +112,7 @@ func (policy *singleCompactionPolicy) triggerSegmentSortCompaction( return nil } - collectionTTL, err := getCollectionTTL(collection.Properties) + collectionTTL, err := common.GetCollectionTTLFromMap(collection.Properties, paramtable.Get().CommonCfg.EntityExpirationTTL.GetAsDuration(time.Second)) if err != nil { log.Warn("failed to apply triggerSegmentSortCompaction, get collection ttl failed") return nil @@ -227,7 +228,7 @@ func (policy *singleCompactionPolicy) triggerOneCollection(ctx context.Context, return nil, nil, 0, nil } - collectionTTL, err := getCollectionTTL(collection.Properties) + collectionTTL, err := common.GetCollectionTTLFromMap(collection.Properties, paramtable.Get().CommonCfg.EntityExpirationTTL.GetAsDuration(time.Second)) if err != nil { log.Warn("failed to apply singleCompactionPolicy, get collection ttl failed") return nil, nil, 0, err diff --git a/internal/datacoord/compaction_trigger.go b/internal/datacoord/compaction_trigger.go index ae783cfd63..86209ef83c 100644 --- a/internal/datacoord/compaction_trigger.go +++ b/internal/datacoord/compaction_trigger.go @@ -29,6 +29,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/datacoord/allocator" + "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" "github.com/milvus-io/milvus/pkg/v2/util/lifetime" @@ -238,7 +239,7 @@ func isCollectionAutoCompactionEnabled(coll *collectionInfo) bool { } func getCompactTime(ts Timestamp, coll *collectionInfo) (*compactTime, error) { - collectionTTL, err := getCollectionTTL(coll.Properties) + collectionTTL, err := common.GetCollectionTTLFromMap(coll.Properties, paramtable.Get().CommonCfg.EntityExpirationTTL.GetAsDuration(time.Second)) if err != nil { return nil, err } diff --git a/internal/datacoord/import_util.go b/internal/datacoord/import_util.go index bd8e38a8f9..11e2cc509e 100644 --- a/internal/datacoord/import_util.go +++ b/internal/datacoord/import_util.go @@ -875,7 +875,7 @@ func createSortCompactionTask(ctx context.Context, return nil, err } - collectionTTL, err := getCollectionTTL(collection.Properties) + collectionTTL, err := common.GetCollectionTTLFromMap(collection.Properties, paramtable.Get().CommonCfg.EntityExpirationTTL.GetAsDuration(time.Second)) if err != nil { log.Warn("failed to apply triggerSegmentSortCompaction, get collection ttl failed") return nil, err diff --git a/internal/datacoord/util.go b/internal/datacoord/util.go index 383f711712..de6487f672 100644 --- a/internal/datacoord/util.go +++ b/internal/datacoord/util.go @@ -146,20 +146,6 @@ func getZeroTime() time.Time { return t } -// getCollectionTTL returns ttl if collection's ttl is specified, or return global ttl -func getCollectionTTL(properties map[string]string) (time.Duration, error) { - v, ok := properties[common.CollectionTTLConfigKey] - if ok { - ttl, err := strconv.Atoi(v) - if err != nil { - return -1, err - } - return time.Duration(ttl) * time.Second, nil - } - - return Params.CommonCfg.EntityExpirationTTL.GetAsDuration(time.Second), nil -} - func UpdateCompactionSegmentSizeMetrics(segments []*datapb.CompactionSegment) { var totalSize int64 for _, seg := range segments { diff --git a/internal/datacoord/util_test.go b/internal/datacoord/util_test.go index daf10e9fe1..7e2fb25693 100644 --- a/internal/datacoord/util_test.go +++ b/internal/datacoord/util_test.go @@ -146,29 +146,6 @@ func (suite *UtilSuite) TestGetZeroTime() { } } -func (suite *UtilSuite) TestGetCollectionTTL() { - properties1 := map[string]string{ - common.CollectionTTLConfigKey: "3600", - } - - // get ttl from configuration file - ttl, err := getCollectionTTL(properties1) - suite.NoError(err) - suite.Equal(ttl, time.Duration(3600)*time.Second) - - properties2 := map[string]string{ - common.CollectionTTLConfigKey: "error value", - } - // test for parsing configuration failed - ttl, err = getCollectionTTL(properties2) - suite.Error(err) - suite.Equal(int(ttl), -1) - - ttl, err = getCollectionTTL(map[string]string{}) - suite.NoError(err) - suite.Equal(ttl, Params.CommonCfg.EntityExpirationTTL.GetAsDuration(time.Second)) -} - func (suite *UtilSuite) TestGetCollectionAutoCompactionEnabled() { properties := map[string]string{ common.CollectionAutoCompactionKey: "true", diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 2872d42273..db61bbe4a6 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -433,6 +433,12 @@ func (t *createCollectionTask) PreExecute(ctx context.Context) error { return merr.WrapErrParameterInvalidMsg("unknown or invalid IANA Time Zone ID: %s", tz) } + // Validate collection ttl + _, err = common.GetCollectionTTL(t.GetProperties(), -1) + if err != nil { + return merr.WrapErrParameterInvalidMsg("collection ttl property value not valid, parse error: %s", err.Error()) + } + // validate clustering key if err := t.validateClusteringKey(ctx); err != nil { return err @@ -1282,6 +1288,11 @@ func (t *alterCollectionTask) PreExecute(ctx context.Context) error { } } + _, err = common.GetCollectionTTL(t.GetProperties(), -1) + if err != nil { + return merr.WrapErrParameterInvalidMsg("collection ttl properties value not valid, parse error: %s", err.Error()) + } + enabled, _ := common.IsAllowInsertAutoID(t.Properties...) if enabled { primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(collSchema.CollectionSchema) diff --git a/internal/proxy/util.go b/internal/proxy/util.go index a6b6803191..f38c806b04 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -2673,22 +2673,20 @@ func GetBM25FunctionOutputFields(collSchema *schemapb.CollectionSchema) []string return fields } +// getCollectionTTL returns ttl if collection's ttl is specified +// or return global ttl if collection's ttl is not specified +// this is a helper util wrapping common.GetCollectionTTL without returning error func getCollectionTTL(pairs []*commonpb.KeyValuePair) uint64 { - properties := make(map[string]string) - for _, pair := range pairs { - properties[pair.Key] = pair.Value + defaultTTL := paramtable.Get().CommonCfg.EntityExpirationTTL.GetAsDuration(time.Second) + ttl, err := common.GetCollectionTTL(pairs, defaultTTL) + if err != nil { + log.Error("failed to get collection ttl, use default ttl", zap.Error(err)) + ttl = defaultTTL } - - v, ok := properties[common.CollectionTTLConfigKey] - if ok { - ttl, err := strconv.Atoi(v) - if err != nil { - return 0 - } - return uint64(time.Duration(ttl) * time.Second) + if ttl < 0 { + return 0 } - - return 0 + return uint64(ttl) } // reconstructStructFieldDataCommon reconstructs struct fields from flattened sub-fields diff --git a/pkg/common/common.go b/pkg/common/common.go index 120411d305..d64810b644 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -22,6 +22,7 @@ import ( "math/bits" "strconv" "strings" + "time" "github.com/cockroachdb/errors" "github.com/samber/lo" @@ -604,6 +605,33 @@ func GetStringValue(kvs []*commonpb.KeyValuePair, key string) (result string, ex return kv.GetValue(), true } +func GetCollectionTTL(kvs []*commonpb.KeyValuePair, defaultValue time.Duration) (time.Duration, error) { + value, parseErr, exist := GetInt64Value(kvs, CollectionTTLConfigKey) + if parseErr != nil { + return 0, parseErr + } + + if !exist { + return defaultValue, nil + } + + return time.Duration(value) * time.Second, nil +} + +func GetCollectionTTLFromMap(kvs map[string]string, defaultValue time.Duration) (time.Duration, error) { + value, exist := kvs[CollectionTTLConfigKey] + if !exist { + return defaultValue, nil + } + + ttlSeconds, err := strconv.ParseInt(value, 10, 64) + if err != nil { + return 0, err + } + + return time.Duration(ttlSeconds) * time.Second, nil +} + func CheckNamespace(schema *schemapb.CollectionSchema, namespace *string) error { enabled, _, err := ParseNamespaceProp(schema.Properties...) if err != nil { diff --git a/pkg/common/common_test.go b/pkg/common/common_test.go index 9d6bf68a26..3ac2f9d727 100644 --- a/pkg/common/common_test.go +++ b/pkg/common/common_test.go @@ -1,8 +1,10 @@ package common import ( + "math/rand" "strings" "testing" + "time" "github.com/stretchr/testify/assert" @@ -247,3 +249,46 @@ func TestIsDisableFuncRuntimeCheck(t *testing.T) { assert.Error(t, err) assert.False(t, disable) } + +func TestGetCollectionTTL(t *testing.T) { + type testCase struct { + tag string + value string + expect time.Duration + expectErr bool + } + + cases := []testCase{ + {tag: "normal_case", value: "3600", expect: time.Duration(3600) * time.Second, expectErr: false}, + {tag: "error_value", value: "error value", expectErr: true}, + {tag: "out_of_int64_range", value: "10000000000000000000000000000000000000000000000000000000000000000000000000000", expectErr: true}, + {tag: "negative", value: "-1", expect: -1 * time.Second}, + } + + for _, tc := range cases { + t.Run(tc.tag, func(t *testing.T) { + result, err := GetCollectionTTL([]*commonpb.KeyValuePair{{Key: CollectionTTLConfigKey, Value: tc.value}}, 0) + if tc.expectErr { + assert.Error(t, err) + } else { + assert.EqualValues(t, tc.expect, result) + } + result, err = GetCollectionTTLFromMap(map[string]string{CollectionTTLConfigKey: tc.value}, 0) + if tc.expectErr { + assert.Error(t, err) + } else { + assert.EqualValues(t, tc.expect, result) + } + }) + } + + t.Run("not_config", func(t *testing.T) { + randValue := rand.Intn(100) + result, err := GetCollectionTTL([]*commonpb.KeyValuePair{}, time.Duration(randValue)*time.Second) + assert.NoError(t, err) + assert.EqualValues(t, time.Duration(randValue)*time.Second, result) + result, err = GetCollectionTTLFromMap(map[string]string{}, time.Duration(randValue)*time.Second) + assert.NoError(t, err) + assert.EqualValues(t, time.Duration(randValue)*time.Second, result) + }) +}