diff --git a/configs/milvus.yaml b/configs/milvus.yaml index e06e28aa53..7c7ed09f85 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -370,6 +370,7 @@ queryNode: serverMaxRecvSize: 268435456 clientMaxSendSize: 268435456 clientMaxRecvSize: 536870912 + enableSegmentPrune: false # use partition prune function on shard delegator indexCoord: bindIndexNodeMode: diff --git a/internal/parser/planparserv2/parser_visitor.go b/internal/parser/planparserv2/parser_visitor.go index 4bc795103d..da2294226b 100644 --- a/internal/parser/planparserv2/parser_visitor.go +++ b/internal/parser/planparserv2/parser_visitor.go @@ -42,13 +42,14 @@ func (v *ParserVisitor) translateIdentifier(identifier string) (*ExprWithType, e Expr: &planpb.Expr_ColumnExpr{ ColumnExpr: &planpb.ColumnExpr{ Info: &planpb.ColumnInfo{ - FieldId: field.FieldID, - DataType: field.DataType, - IsPrimaryKey: field.IsPrimaryKey, - IsAutoID: field.AutoID, - NestedPath: nestedPath, - IsPartitionKey: field.IsPartitionKey, - ElementType: field.GetElementType(), + FieldId: field.FieldID, + DataType: field.DataType, + IsPrimaryKey: field.IsPrimaryKey, + IsAutoID: field.AutoID, + NestedPath: nestedPath, + IsPartitionKey: field.IsPartitionKey, + IsClusteringKey: field.IsClusteringKey, + ElementType: field.GetElementType(), }, }, }, diff --git a/internal/proto/plan.proto b/internal/proto/plan.proto index fd6710d1f5..62ccae20b5 100644 --- a/internal/proto/plan.proto +++ b/internal/proto/plan.proto @@ -71,6 +71,7 @@ message ColumnInfo { repeated string nested_path = 5; bool is_partition_key = 6; schema.DataType element_type = 7; + bool is_clustering_key = 8; } message ColumnExpr { diff --git a/internal/proxy/expr_checker.go b/internal/proxy/expr_checker.go deleted file mode 100644 index 6c2930fac4..0000000000 --- a/internal/proxy/expr_checker.go +++ /dev/null @@ -1,114 +0,0 @@ -package proxy - -import ( - "github.com/cockroachdb/errors" - - "github.com/milvus-io/milvus/internal/proto/planpb" -) - -func ParseExprFromPlan(plan *planpb.PlanNode) (*planpb.Expr, error) { - node := plan.GetNode() - - if node == nil { - return nil, errors.New("can't get expr from empty plan node") - } - - var expr *planpb.Expr - switch node := node.(type) { - case *planpb.PlanNode_VectorAnns: - expr = node.VectorAnns.GetPredicates() - case *planpb.PlanNode_Query: - expr = node.Query.GetPredicates() - default: - return nil, errors.New("unsupported plan node type") - } - - return expr, nil -} - -func ParsePartitionKeysFromBinaryExpr(expr *planpb.BinaryExpr) ([]*planpb.GenericValue, bool) { - leftRes, leftInRange := ParsePartitionKeysFromExpr(expr.Left) - RightRes, rightInRange := ParsePartitionKeysFromExpr(expr.Right) - - if expr.Op == planpb.BinaryExpr_LogicalAnd { - // case: partition_key_field in [7, 8] && partition_key > 8 - if len(leftRes)+len(RightRes) > 0 { - leftRes = append(leftRes, RightRes...) - return leftRes, false - } - - // case: other_field > 10 && partition_key_field > 8 - return nil, leftInRange || rightInRange - } - - if expr.Op == planpb.BinaryExpr_LogicalOr { - // case: partition_key_field in [7, 8] or partition_key > 8 - if leftInRange || rightInRange { - return nil, true - } - - // case: partition_key_field in [7, 8] or other_field > 10 - leftRes = append(leftRes, RightRes...) 
- return leftRes, false - } - - return nil, false -} - -func ParsePartitionKeysFromUnaryExpr(expr *planpb.UnaryExpr) ([]*planpb.GenericValue, bool) { - res, partitionInRange := ParsePartitionKeysFromExpr(expr.GetChild()) - if expr.Op == planpb.UnaryExpr_Not { - // case: partition_key_field not in [7, 8] - if len(res) != 0 { - return nil, true - } - - // case: other_field not in [10] - return nil, partitionInRange - } - - // UnaryOp only includes "Not" for now - return res, partitionInRange -} - -func ParsePartitionKeysFromTermExpr(expr *planpb.TermExpr) ([]*planpb.GenericValue, bool) { - if expr.GetColumnInfo().GetIsPartitionKey() { - return expr.GetValues(), false - } - - return nil, false -} - -func ParsePartitionKeysFromUnaryRangeExpr(expr *planpb.UnaryRangeExpr) ([]*planpb.GenericValue, bool) { - if expr.GetColumnInfo().GetIsPartitionKey() && expr.GetOp() == planpb.OpType_Equal { - return []*planpb.GenericValue{expr.Value}, false - } - - return nil, true -} - -func ParsePartitionKeysFromExpr(expr *planpb.Expr) ([]*planpb.GenericValue, bool) { - var res []*planpb.GenericValue - partitionKeyInRange := false - switch expr := expr.GetExpr().(type) { - case *planpb.Expr_BinaryExpr: - res, partitionKeyInRange = ParsePartitionKeysFromBinaryExpr(expr.BinaryExpr) - case *planpb.Expr_UnaryExpr: - res, partitionKeyInRange = ParsePartitionKeysFromUnaryExpr(expr.UnaryExpr) - case *planpb.Expr_TermExpr: - res, partitionKeyInRange = ParsePartitionKeysFromTermExpr(expr.TermExpr) - case *planpb.Expr_UnaryRangeExpr: - res, partitionKeyInRange = ParsePartitionKeysFromUnaryRangeExpr(expr.UnaryRangeExpr) - } - - return res, partitionKeyInRange -} - -func ParsePartitionKeys(expr *planpb.Expr) []*planpb.GenericValue { - res, partitionKeyInRange := ParsePartitionKeysFromExpr(expr) - if partitionKeyInRange { - res = nil - } - - return res -} diff --git a/internal/proxy/expr_checker_test.go b/internal/proxy/expr_checker_test.go deleted file mode 100644 index 12c6ed085f..0000000000 --- a/internal/proxy/expr_checker_test.go +++ /dev/null @@ -1,143 +0,0 @@ -package proxy - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/parser/planparserv2" - "github.com/milvus-io/milvus/internal/proto/planpb" - "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/util/funcutil" - "github.com/milvus-io/milvus/pkg/util/typeutil" -) - -func TestParsePartitionKeys(t *testing.T) { - prefix := "TestParsePartitionKeys" - collectionName := prefix + funcutil.GenRandomStr() - - fieldName2Type := make(map[string]schemapb.DataType) - fieldName2Type["int64_field"] = schemapb.DataType_Int64 - fieldName2Type["varChar_field"] = schemapb.DataType_VarChar - fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector - schema := constructCollectionSchemaByDataType(collectionName, fieldName2Type, "int64_field", false) - partitionKeyField := &schemapb.FieldSchema{ - Name: "partition_key_field", - DataType: schemapb.DataType_Int64, - IsPartitionKey: true, - } - schema.Fields = append(schema.Fields, partitionKeyField) - - schemaHelper, err := typeutil.CreateSchemaHelper(schema) - require.NoError(t, err) - fieldID := common.StartOfUserFieldID - for _, field := range schema.Fields { - field.FieldID = int64(fieldID) - fieldID++ - } - - queryInfo := &planpb.QueryInfo{ - Topk: 10, - MetricType: "L2", - SearchParams: "", - RoundDecimal: -1, - } - - type testCase struct { - name 
string - expr string - expected int - validPartitionKeys []int64 - invalidPartitionKeys []int64 - } - cases := []testCase{ - { - name: "binary_expr_and with term", - expr: "partition_key_field in [7, 8] && int64_field >= 10", - expected: 2, - validPartitionKeys: []int64{7, 8}, - invalidPartitionKeys: []int64{}, - }, - { - name: "binary_expr_and with equal", - expr: "partition_key_field == 7 && int64_field >= 10", - expected: 1, - validPartitionKeys: []int64{7}, - invalidPartitionKeys: []int64{}, - }, - { - name: "binary_expr_and with term2", - expr: "partition_key_field in [7, 8] && int64_field == 10", - expected: 2, - validPartitionKeys: []int64{7, 8}, - invalidPartitionKeys: []int64{10}, - }, - { - name: "binary_expr_and with partition key in range", - expr: "partition_key_field in [7, 8] && partition_key_field > 9", - expected: 2, - validPartitionKeys: []int64{7, 8}, - invalidPartitionKeys: []int64{9}, - }, - { - name: "binary_expr_and with partition key in range2", - expr: "int64_field == 10 && partition_key_field > 9", - expected: 0, - validPartitionKeys: []int64{}, - invalidPartitionKeys: []int64{}, - }, - { - name: "binary_expr_and with term and not", - expr: "partition_key_field in [7, 8] && partition_key_field not in [10, 20]", - expected: 2, - validPartitionKeys: []int64{7, 8}, - invalidPartitionKeys: []int64{10, 20}, - }, - { - name: "binary_expr_or with term and not", - expr: "partition_key_field in [7, 8] or partition_key_field not in [10, 20]", - expected: 0, - validPartitionKeys: []int64{}, - invalidPartitionKeys: []int64{}, - }, - { - name: "binary_expr_or with term and not 2", - expr: "partition_key_field in [7, 8] or int64_field not in [10, 20]", - expected: 2, - validPartitionKeys: []int64{7, 8}, - invalidPartitionKeys: []int64{10, 20}, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - // test search plan - searchPlan, err := planparserv2.CreateSearchPlan(schemaHelper, tc.expr, "fvec_field", queryInfo) - assert.NoError(t, err) - expr, err := ParseExprFromPlan(searchPlan) - assert.NoError(t, err) - partitionKeys := ParsePartitionKeys(expr) - assert.Equal(t, tc.expected, len(partitionKeys)) - for _, key := range partitionKeys { - int64Val := key.Val.(*planpb.GenericValue_Int64Val).Int64Val - assert.Contains(t, tc.validPartitionKeys, int64Val) - assert.NotContains(t, tc.invalidPartitionKeys, int64Val) - } - - // test query plan - queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, tc.expr) - assert.NoError(t, err) - expr, err = ParseExprFromPlan(queryPlan) - assert.NoError(t, err) - partitionKeys = ParsePartitionKeys(expr) - assert.Equal(t, tc.expected, len(partitionKeys)) - for _, key := range partitionKeys { - int64Val := key.Val.(*planpb.GenericValue_Int64Val).Int64Val - assert.Contains(t, tc.validPartitionKeys, int64Val) - assert.NotContains(t, tc.invalidPartitionKeys, int64Val) - } - }) - } -} diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go index 6b1f466778..872169e8ca 100644 --- a/internal/proxy/search_util.go +++ b/internal/proxy/search_util.go @@ -12,6 +12,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/parser/planparserv2" + "github.com/milvus-io/milvus/internal/util/exprutil" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/merr" @@ -92,12 +93,12 @@ func initSearchRequest(ctx context.Context, t *searchTask) error { zap.String("anns field", annsField), 
zap.Any("query info", queryInfo)) if t.partitionKeyMode { - expr, err := ParseExprFromPlan(plan) + expr, err := exprutil.ParseExprFromPlan(plan) if err != nil { log.Warn("failed to parse expr", zap.Error(err)) return err } - partitionKeys := ParsePartitionKeys(expr) + partitionKeys := exprutil.ParseKeys(expr, exprutil.PartitionKey) hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.collectionName, partitionKeys) if err != nil { log.Warn("failed to assign partition keys", zap.Error(err)) diff --git a/internal/proxy/task_delete.go b/internal/proxy/task_delete.go index f7e4c8210e..a64758b485 100644 --- a/internal/proxy/task_delete.go +++ b/internal/proxy/task_delete.go @@ -21,6 +21,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/exprutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" @@ -356,11 +357,11 @@ func (dr *deleteRunner) getStreamingQueryAndDelteFunc(plan *planpb.PlanNode) exe // optimize query when partitionKey on if dr.partitionKeyMode { - expr, err := ParseExprFromPlan(plan) + expr, err := exprutil.ParseExprFromPlan(plan) if err != nil { return err } - partitionKeys := ParsePartitionKeys(expr) + partitionKeys := exprutil.ParseKeys(expr, exprutil.PartitionKey) hashedPartitionNames, err := assignPartitionKeys(ctx, dr.req.GetDbName(), dr.req.GetCollectionName(), partitionKeys) if err != nil { return err diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 8024274473..8b95e7add0 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -19,6 +19,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/types" + "github.com/milvus-io/milvus/internal/util/exprutil" typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" @@ -358,11 +359,11 @@ func (t *queryTask) PreExecute(ctx context.Context) error { if !t.reQuery { partitionNames := t.request.GetPartitionNames() if t.partitionKeyMode { - expr, err := ParseExprFromPlan(t.plan) + expr, err := exprutil.ParseExprFromPlan(t.plan) if err != nil { return err } - partitionKeys := ParsePartitionKeys(expr) + partitionKeys := exprutil.ParseKeys(expr, exprutil.PartitionKey) hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.request.CollectionName, partitionKeys) if err != nil { return err diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index 7c99b268d8..1acd76684f 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -20,6 +20,8 @@ package delegator import ( "context" "fmt" + "path" + "strconv" "sync" "time" @@ -42,6 +44,7 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/tsafe" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/streamrpc" + "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/metrics" "github.com/milvus-io/milvus/pkg/mq/msgstream" @@ -49,6 +52,7 @@ import ( "github.com/milvus-io/milvus/pkg/util/funcutil" "github.com/milvus-io/milvus/pkg/util/lifetime" "github.com/milvus-io/milvus/pkg/util/merr" 
+ "github.com/milvus-io/milvus/pkg/util/metautil" "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/timerecord" "github.com/milvus-io/milvus/pkg/util/tsoutil" @@ -115,7 +119,9 @@ type shardDelegator struct { tsCond *sync.Cond latestTsafe *atomic.Uint64 // queryHook - queryHook optimizers.QueryHook + queryHook optimizers.QueryHook + partitionStats map[UniqueID]*storage.PartitionStatsSnapshot + chunkManager storage.ChunkManager } // getLogger returns the zap logger with pre-defined shard attributes. @@ -203,6 +209,9 @@ func (sd *shardDelegator) search(ctx context.Context, req *querypb.SearchRequest log.Warn("failed to optimize search params", zap.Error(err)) return nil, err } + if paramtable.Get().QueryNodeCfg.EnableSegmentPrune.GetAsBool() { + PruneSegments(ctx, sd.partitionStats, req.GetReq(), nil, sd.collection.Schema(), sealed, PruneInfo{filterRatio: defaultFilterRatio}) + } tasks, err := organizeSubTask(ctx, req, sealed, growing, sd, sd.modifySearchRequest) if err != nil { @@ -485,12 +494,17 @@ func (sd *shardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) return nil, merr.WrapErrChannelNotAvailable(sd.vchannelName, "distribution is not servcieable") } defer sd.distribution.Unpin(version) - existPartitions := sd.collection.GetPartitions() - growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool { - return funcutil.SliceContain(existPartitions, segment.PartitionID) - }) if req.Req.IgnoreGrowing { growing = []SegmentEntry{} + } else { + existPartitions := sd.collection.GetPartitions() + growing = lo.Filter(growing, func(segment SegmentEntry, _ int) bool { + return funcutil.SliceContain(existPartitions, segment.PartitionID) + }) + } + + if paramtable.Get().QueryNodeCfg.EnableSegmentPrune.GetAsBool() { + PruneSegments(ctx, sd.partitionStats, nil, req.GetReq(), sd.collection.Schema(), sealed, PruneInfo{defaultFilterRatio}) } sealedNum := lo.SumBy(sealed, func(item SnapshotItem) int { return len(item.Segments) }) @@ -774,10 +788,72 @@ func (sd *shardDelegator) Close() { sd.lifetime.Wait() } +// As partition stats is an optimization for search/query which is not mandatory for milvus instance, +// loading partitionStats will be a try-best process and will skip+logError when running across errors rather than +// return an error status +func (sd *shardDelegator) maybeReloadPartitionStats(ctx context.Context, partIDs ...UniqueID) { + var partsToReload []UniqueID + if len(partIDs) > 0 { + partsToReload = partIDs + } else { + partsToReload = append(partsToReload, sd.collection.GetPartitions()...) 
+ } + + colID := sd.Collection() + findMaxVersion := func(filePaths []string) (int64, string) { + maxVersion := int64(-1) + maxVersionFilePath := "" + for _, filePath := range filePaths { + versionStr := path.Base(filePath) + version, err := strconv.ParseInt(versionStr, 10, 64) + if err != nil { + continue + } + if version > maxVersion { + maxVersion = version + maxVersionFilePath = filePath + } + } + return maxVersion, maxVersionFilePath + } + for _, partID := range partsToReload { + idPath := metautil.JoinIDPath(colID, partID) + idPath = path.Join(idPath, sd.vchannelName) + statsPathPrefix := path.Join(sd.chunkManager.RootPath(), common.PartitionStatsPath, idPath) + filePaths, _, err := sd.chunkManager.ListWithPrefix(ctx, statsPathPrefix, true) + if err != nil { + log.Error("Skip initializing partition stats for failing to list files with prefix", + zap.String("statsPathPrefix", statsPathPrefix)) + continue + } + maxVersion, maxVersionFilePath := findMaxVersion(filePaths) + if maxVersion < 0 { + log.Info("failed to find valid partition stats file for partition", zap.Int64("partitionID", partID)) + continue + } + partStats, exists := sd.partitionStats[partID] + if !exists || (exists && partStats.GetVersion() < maxVersion) { + statsBytes, err := sd.chunkManager.Read(ctx, maxVersionFilePath) + if err != nil { + log.Error("failed to read stats file from object storage", zap.String("path", maxVersionFilePath)) + continue + } + partStats, err := storage.DeserializePartitionsStatsSnapshot(statsBytes) + if err != nil { + log.Error("failed to parse partition stats from bytes", zap.Int("bytes_length", len(statsBytes))) + continue + } + sd.partitionStats[partID] = partStats + partStats.SetVersion(maxVersion) + log.Info("Updated partitionStats for partition", zap.Int64("partitionID", partID)) + } + } +} + // NewShardDelegator creates a new ShardDelegator instance with all fields initialized. func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID UniqueID, channel string, version int64, workerManager cluster.Manager, manager *segments.Manager, tsafeManager tsafe.Manager, loader segments.Loader, - factory msgstream.Factory, startTs uint64, queryHook optimizers.QueryHook, + factory msgstream.Factory, startTs uint64, queryHook optimizers.QueryHook, chunkManager storage.ChunkManager, ) (ShardDelegator, error) { log := log.Ctx(ctx).With(zap.Int64("collectionID", collectionID), zap.Int64("replicaID", replicaID), @@ -812,6 +888,8 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni loader: loader, factory: factory, queryHook: queryHook, + chunkManager: chunkManager, + partitionStats: make(map[UniqueID]*storage.PartitionStatsSnapshot), } m := sync.Mutex{} sd.tsCond = sync.NewCond(&m) @@ -819,5 +897,6 @@ func NewShardDelegator(ctx context.Context, collectionID UniqueID, replicaID Uni go sd.watchTSafe() } log.Info("finish build new shardDelegator") + sd.maybeReloadPartitionStats(ctx) return sd, nil } diff --git a/internal/querynodev2/delegator/delegator_data.go b/internal/querynodev2/delegator/delegator_data.go index 6baed214b5..902bddd0ba 100644 --- a/internal/querynodev2/delegator/delegator_data.go +++ b/internal/querynodev2/delegator/delegator_data.go @@ -475,6 +475,12 @@ func (sd *shardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSeg // alter distribution sd.distribution.AddDistributions(entries...) 
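The reload path above deserializes whatever bytes it finds at the chosen path, and the tests below build stats the same way, so the round trip can be sketched in isolation. A hedged example using only the helpers that appear in this diff (`storage.NewSegmentStats`, `SerializePartitionStatsSnapshot`, `DeserializePartitionsStatsSnapshot`); the field ID, row count, and version are made up and error handling is shortened.

```go
package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/storage"
)

func main() {
	// One segment with a single int64 clustering-key column, min=100 max=200.
	fieldStats := []storage.FieldStats{{
		FieldID: 101,
		Type:    schemapb.DataType_Int64,
		Min:     storage.NewInt64FieldValue(100),
		Max:     storage.NewInt64FieldValue(200),
	}}
	snapshot := &storage.PartitionStatsSnapshot{
		SegmentStats: map[storage.UniqueID]storage.SegmentStats{
			1: *storage.NewSegmentStats(fieldStats, 1990),
		},
	}

	// Serialize as the stats writer would, then parse it back the way
	// maybeReloadPartitionStats does after reading the object.
	bs, err := storage.SerializePartitionStatsSnapshot(snapshot)
	if err != nil {
		panic(err)
	}
	parsed, err := storage.DeserializePartitionsStatsSnapshot(bs)
	if err != nil {
		panic(err)
	}
	parsed.SetVersion(3) // the version comes from the object name, not the payload
	fmt.Println(parsed.GetVersion(), parsed.SegmentStats[1].NumRows) // 3 1990
}
```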
+ partStatsToReload := make([]UniqueID, 0) + lo.ForEach(req.GetInfos(), func(info *querypb.SegmentLoadInfo, _ int) { + partStatsToReload = append(partStatsToReload, info.PartitionID) + }) + sd.maybeReloadPartitionStats(ctx, partStatsToReload...) + return nil } @@ -850,7 +856,14 @@ func (sd *shardDelegator) ReleaseSegments(ctx context.Context, req *querypb.Rele if hasLevel0 { sd.GenerateLevel0DeletionCache() } - + partitionsToReload := make([]UniqueID, 0) + lo.ForEach(req.GetSegmentIDs(), func(segmentID int64, _ int) { + segment := sd.segmentManager.Get(segmentID) + if segment != nil { + partitionsToReload = append(partitionsToReload, segment.Partition()) + } + }) + sd.maybeReloadPartitionStats(ctx, partitionsToReload...) return nil } diff --git a/internal/querynodev2/delegator/delegator_data_test.go b/internal/querynodev2/delegator/delegator_data_test.go index e852a0f54f..e3b1b57d49 100644 --- a/internal/querynodev2/delegator/delegator_data_test.go +++ b/internal/querynodev2/delegator/delegator_data_test.go @@ -18,6 +18,8 @@ package delegator import ( "context" + "path" + "strconv" "testing" bloom "github.com/bits-and-blooms/bloom/v3" @@ -41,6 +43,7 @@ import ( "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/commonpbutil" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/metautil" "github.com/milvus-io/milvus/pkg/util/metric" "github.com/milvus-io/milvus/pkg/util/paramtable" ) @@ -58,7 +61,9 @@ type DelegatorDataSuite struct { loader *segments.MockLoader mq *msgstream.MockMsgStream - delegator *shardDelegator + delegator *shardDelegator + rootPath string + chunkManager storage.ChunkManager } func (s *DelegatorDataSuite) SetupSuite() { @@ -126,16 +131,19 @@ func (s *DelegatorDataSuite) SetupTest() { }, }, }, &querypb.LoadMetaInfo{ - LoadType: querypb.LoadType_LoadCollection, + LoadType: querypb.LoadType_LoadCollection, + PartitionIDs: []int64{1001, 1002}, }) s.mq = &msgstream.MockMsgStream{} - + s.rootPath = s.Suite.T().Name() + chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath) + s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background()) delegator, err := NewShardDelegator(context.Background(), s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, s.manager, s.tsafeManager, s.loader, &msgstream.MockMqFactory{ NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) { return s.mq, nil }, - }, 10000, nil) + }, 10000, nil, s.chunkManager) s.Require().NoError(err) sd, ok := delegator.(*shardDelegator) s.Require().True(ok) @@ -609,7 +617,7 @@ func (s *DelegatorDataSuite) TestLoadSegments() { NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) { return s.mq, nil }, - }, 10000, nil) + }, 10000, nil, nil) s.NoError(err) growing0 := segments.NewMockSegment(s.T()) @@ -968,6 +976,78 @@ func (s *DelegatorDataSuite) TestReleaseSegment() { s.NoError(err) } +func (s *DelegatorDataSuite) TestLoadPartitionStats() { + segStats := make(map[UniqueID]storage.SegmentStats) + centroid := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0} + var segID int64 = 1 + rows := 1990 + { + // p1 stats + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: 1, + Type: schemapb.DataType_Int64, + Max: storage.NewInt64FieldValue(200), + Min: storage.NewInt64FieldValue(100), + } + fieldStat2 := storage.FieldStats{ + FieldID: 2, + Type: schemapb.DataType_Int64, + Max: storage.NewInt64FieldValue(400), 
+ Min: storage.NewInt64FieldValue(300), + } + fieldStat3 := storage.FieldStats{ + FieldID: 3, + Type: schemapb.DataType_FloatVector, + Centroids: []storage.VectorFieldValue{ + &storage.FloatVectorFieldValue{ + Value: centroid, + }, + &storage.FloatVectorFieldValue{ + Value: centroid, + }, + }, + } + fieldStats = append(fieldStats, fieldStat1) + fieldStats = append(fieldStats, fieldStat2) + fieldStats = append(fieldStats, fieldStat3) + segStats[segID] = *storage.NewSegmentStats(fieldStats, rows) + } + partitionStats1 := &storage.PartitionStatsSnapshot{ + SegmentStats: segStats, + } + statsData1, err := storage.SerializePartitionStatsSnapshot(partitionStats1) + s.NoError(err) + partitionID1 := int64(1001) + idPath1 := metautil.JoinIDPath(s.collectionID, partitionID1) + idPath1 = path.Join(idPath1, s.delegator.vchannelName) + statsPath1 := path.Join(s.chunkManager.RootPath(), common.PartitionStatsPath, idPath1, strconv.Itoa(1)) + s.chunkManager.Write(context.Background(), statsPath1, statsData1) + defer s.chunkManager.Remove(context.Background(), statsPath1) + + // reload and check partition stats + s.delegator.maybeReloadPartitionStats(context.Background()) + s.Equal(1, len(s.delegator.partitionStats)) + s.NotNil(s.delegator.partitionStats[partitionID1]) + p1Stats := s.delegator.partitionStats[partitionID1] + s.Equal(int64(1), p1Stats.GetVersion()) + s.Equal(rows, p1Stats.SegmentStats[segID].NumRows) + s.Equal(3, len(p1Stats.SegmentStats[segID].FieldStats)) + + // judge vector stats + vecFieldStats := p1Stats.SegmentStats[segID].FieldStats[2] + s.Equal(2, len(vecFieldStats.Centroids)) + s.Equal(8, len(vecFieldStats.Centroids[0].GetValue().([]float32))) + + // judge scalar stats + fieldStats1 := p1Stats.SegmentStats[segID].FieldStats[0] + s.Equal(int64(100), fieldStats1.Min.GetValue().(int64)) + s.Equal(int64(200), fieldStats1.Max.GetValue().(int64)) + fieldStats2 := p1Stats.SegmentStats[segID].FieldStats[1] + s.Equal(int64(300), fieldStats2.Min.GetValue().(int64)) + s.Equal(int64(400), fieldStats2.Max.GetValue().(int64)) +} + func (s *DelegatorDataSuite) TestSyncTargetVersion() { for i := int64(0); i < 5; i++ { ms := &segments.MockSegment{} diff --git a/internal/querynodev2/delegator/delegator_test.go b/internal/querynodev2/delegator/delegator_test.go index 2cfcd1115a..a5d5264e59 100644 --- a/internal/querynodev2/delegator/delegator_test.go +++ b/internal/querynodev2/delegator/delegator_test.go @@ -40,6 +40,7 @@ import ( "github.com/milvus-io/milvus/internal/querynodev2/cluster" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/tsafe" + "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/streamrpc" "github.com/milvus-io/milvus/pkg/common" "github.com/milvus-io/milvus/pkg/mq/msgstream" @@ -64,7 +65,9 @@ type DelegatorSuite struct { loader *segments.MockLoader mq *msgstream.MockMsgStream - delegator ShardDelegator + delegator ShardDelegator + chunkManager storage.ChunkManager + rootPath string } func (s *DelegatorSuite) SetupSuite() { @@ -154,6 +157,11 @@ func (s *DelegatorSuite) SetupTest() { }) s.mq = &msgstream.MockMsgStream{} + s.rootPath = "delegator_test" + + // init chunkManager + chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), s.rootPath) + s.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(context.Background()) var err error // s.delegator, err = NewShardDelegator(s.collectionID, s.replicaID, s.vchannelName, s.version, s.workerManager, 
s.manager, s.tsafeManager, s.loader) @@ -161,7 +169,7 @@ func (s *DelegatorSuite) SetupTest() { NewMsgStreamFunc: func(_ context.Context) (msgstream.MsgStream, error) { return s.mq, nil }, - }, 10000, nil) + }, 10000, nil, s.chunkManager) s.Require().NoError(err) } diff --git a/internal/querynodev2/delegator/segment_pruner.go b/internal/querynodev2/delegator/segment_pruner.go new file mode 100644 index 0000000000..7b6bd9acbb --- /dev/null +++ b/internal/querynodev2/delegator/segment_pruner.go @@ -0,0 +1,228 @@ +package delegator + +import ( + "context" + "sort" + "strconv" + + "github.com/golang/protobuf/proto" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/planpb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/clustering" + "github.com/milvus-io/milvus/internal/util/exprutil" + "github.com/milvus-io/milvus/internal/util/typeutil" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/distance" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +const defaultFilterRatio float64 = 0.5 + +type PruneInfo struct { + filterRatio float64 +} + +func PruneSegments(ctx context.Context, + partitionStats map[UniqueID]*storage.PartitionStatsSnapshot, + searchReq *internalpb.SearchRequest, + queryReq *internalpb.RetrieveRequest, + schema *schemapb.CollectionSchema, + sealedSegments []SnapshotItem, + info PruneInfo, +) { + log := log.Ctx(ctx) + // 1. calculate filtered segments + filteredSegments := make(map[UniqueID]struct{}, 0) + clusteringKeyField := typeutil.GetClusteringKeyField(schema.Fields) + if clusteringKeyField == nil { + return + } + if searchReq != nil { + // parse searched vectors + var vectorsHolder commonpb.PlaceholderGroup + err := proto.Unmarshal(searchReq.GetPlaceholderGroup(), &vectorsHolder) + if err != nil || len(vectorsHolder.GetPlaceholders()) == 0 { + return + } + vectorsBytes := vectorsHolder.GetPlaceholders()[0].GetValues() + // parse dim + dimStr, err := funcutil.GetAttrByKeyFromRepeatedKV(common.DimKey, clusteringKeyField.GetTypeParams()) + if err != nil { + return + } + dimValue, err := strconv.ParseInt(dimStr, 10, 64) + if err != nil { + return + } + for _, partID := range searchReq.GetPartitionIDs() { + partStats := partitionStats[partID] + FilterSegmentsByVector(partStats, searchReq, vectorsBytes, dimValue, clusteringKeyField, filteredSegments, info.filterRatio) + } + } else if queryReq != nil { + // 0. parse expr from plan + plan := planpb.PlanNode{} + err := proto.Unmarshal(queryReq.GetSerializedExprPlan(), &plan) + if err != nil { + log.Error("failed to unmarshall serialized expr from bytes, failed the operation") + return + } + expr, err := exprutil.ParseExprFromPlan(&plan) + if err != nil { + log.Error("failed to parse expr from plan, failed the operation") + return + } + targetRanges, matchALL := exprutil.ParseRanges(expr, exprutil.ClusteringKey) + if matchALL || targetRanges == nil { + return + } + for _, partID := range queryReq.GetPartitionIDs() { + partStats := partitionStats[partID] + FilterSegmentsOnScalarField(partStats, targetRanges, clusteringKeyField, filteredSegments) + } + } + + // 2. 
remove filtered segments from sealed segment list + if len(filteredSegments) > 0 { + totalSegNum := 0 + for idx, item := range sealedSegments { + newSegments := make([]SegmentEntry, 0) + totalSegNum += len(item.Segments) + for _, segment := range item.Segments { + if _, ok := filteredSegments[segment.SegmentID]; !ok { + newSegments = append(newSegments, segment) + } + } + item.Segments = newSegments + sealedSegments[idx] = item + } + log.Debug("Pruned segment for search/query", + zap.Int("pruned_segment_num", len(filteredSegments)), + zap.Int("total_segment_num", totalSegNum), + ) + } +} + +type segmentDisStruct struct { + segmentID UniqueID + distance float32 + rows int // for keep track of sufficiency of topK +} + +func FilterSegmentsByVector(partitionStats *storage.PartitionStatsSnapshot, + searchReq *internalpb.SearchRequest, + vectorBytes [][]byte, + dim int64, + keyField *schemapb.FieldSchema, + filteredSegments map[UniqueID]struct{}, + filterRatio float64, +) { + // 1. calculate vectors' distances + neededSegments := make(map[UniqueID]struct{}) + for _, vecBytes := range vectorBytes { + segmentsToSearch := make([]segmentDisStruct, 0) + for segId, segStats := range partitionStats.SegmentStats { + // here, we do not skip needed segments required by former query vector + // meaning that repeated calculation will be carried and the larger the nq is + // the more segments have to be included and prune effect will decline + // 1. calculate distances from centroids + for _, fieldStat := range segStats.FieldStats { + if fieldStat.FieldID == keyField.GetFieldID() { + if fieldStat.Centroids == nil || len(fieldStat.Centroids) == 0 { + neededSegments[segId] = struct{}{} + break + } + var dis []float32 + var disErr error + switch keyField.GetDataType() { + case schemapb.DataType_FloatVector: + dis, disErr = clustering.CalcVectorDistance(dim, keyField.GetDataType(), + vecBytes, fieldStat.Centroids[0].GetValue().([]float32), searchReq.GetMetricType()) + default: + neededSegments[segId] = struct{}{} + disErr = merr.WrapErrParameterInvalid(schemapb.DataType_FloatVector, keyField.GetDataType(), + "Currently, pruning by cluster only support float_vector type") + } + // currently, we only support float vector and only one center one segment + if disErr != nil { + neededSegments[segId] = struct{}{} + break + } + segmentsToSearch = append(segmentsToSearch, segmentDisStruct{ + segmentID: segId, + distance: dis[0], + rows: segStats.NumRows, + }) + break + } + } + } + // 2. sort the distances + switch searchReq.GetMetricType() { + case distance.L2: + sort.SliceStable(segmentsToSearch, func(i, j int) bool { + return segmentsToSearch[i].distance < segmentsToSearch[j].distance + }) + case distance.IP, distance.COSINE: + sort.SliceStable(segmentsToSearch, func(i, j int) bool { + return segmentsToSearch[i].distance > segmentsToSearch[j].distance + }) + } + + // 3. filtered non-target segments + segmentCount := len(segmentsToSearch) + targetSegNum := int(float64(segmentCount) * filterRatio) + optimizedRowCount := 0 + // set the last n - targetSegNum as being filtered + for i := 0; i < segmentCount; i++ { + optimizedRowCount += segmentsToSearch[i].rows + neededSegments[segmentsToSearch[i].segmentID] = struct{}{} + if int64(optimizedRowCount) >= searchReq.GetTopk() && i >= targetSegNum { + break + } + } + } + + // 3. 
set not needed segments as removed + for segId := range partitionStats.SegmentStats { + if _, ok := neededSegments[segId]; !ok { + filteredSegments[segId] = struct{}{} + } + } +} + +func FilterSegmentsOnScalarField(partitionStats *storage.PartitionStatsSnapshot, + targetRanges []*exprutil.PlanRange, + keyField *schemapb.FieldSchema, + filteredSegments map[UniqueID]struct{}, +) { + // 1. try to filter segments + overlap := func(min storage.ScalarFieldValue, max storage.ScalarFieldValue) bool { + for _, tRange := range targetRanges { + switch keyField.DataType { + case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32, schemapb.DataType_Int64: + targetRange := tRange.ToIntRange() + statRange := exprutil.NewIntRange(min.GetValue().(int64), max.GetValue().(int64), true, true) + return exprutil.IntRangeOverlap(targetRange, statRange) + case schemapb.DataType_String, schemapb.DataType_VarChar: + targetRange := tRange.ToStrRange() + statRange := exprutil.NewStrRange(min.GetValue().(string), max.GetValue().(string), true, true) + return exprutil.StrRangeOverlap(targetRange, statRange) + } + } + return false + } + for segID, segStats := range partitionStats.SegmentStats { + for _, fieldStat := range segStats.FieldStats { + if keyField.FieldID == fieldStat.FieldID && !overlap(fieldStat.Min, fieldStat.Max) { + filteredSegments[segID] = struct{}{} + } + } + } +} diff --git a/internal/querynodev2/delegator/segment_pruner_test.go b/internal/querynodev2/delegator/segment_pruner_test.go new file mode 100644 index 0000000000..cdfeb8a304 --- /dev/null +++ b/internal/querynodev2/delegator/segment_pruner_test.go @@ -0,0 +1,422 @@ +package delegator + +import ( + "context" + "testing" + + "github.com/golang/protobuf/proto" + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/parser/planparserv2" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/internal/util/clustering" + "github.com/milvus-io/milvus/internal/util/testutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type SegmentPrunerSuite struct { + suite.Suite + partitionStats map[UniqueID]*storage.PartitionStatsSnapshot + schema *schemapb.CollectionSchema + collectionName string + primaryFieldName string + clusterKeyFieldName string + autoID bool + targetPartition int64 + dim int + sealedSegments []SnapshotItem +} + +func (sps *SegmentPrunerSuite) SetupForClustering(clusterKeyFieldName string, + clusterKeyFieldType schemapb.DataType, +) { + sps.collectionName = "test_segment_prune" + sps.primaryFieldName = "pk" + sps.clusterKeyFieldName = clusterKeyFieldName + sps.autoID = true + sps.dim = 8 + + fieldName2DataType := make(map[string]schemapb.DataType) + fieldName2DataType[sps.primaryFieldName] = schemapb.DataType_Int64 + fieldName2DataType[sps.clusterKeyFieldName] = clusterKeyFieldType + fieldName2DataType["info"] = schemapb.DataType_VarChar + fieldName2DataType["age"] = schemapb.DataType_Int32 + fieldName2DataType["vec"] = schemapb.DataType_FloatVector + + sps.schema = testutil.ConstructCollectionSchemaWithKeys(sps.collectionName, + fieldName2DataType, + sps.primaryFieldName, + "", + sps.clusterKeyFieldName, + false, + sps.dim) + + var clusteringKeyFieldID int64 = 0 + for _, field := range sps.schema.GetFields() { + if field.IsClusteringKey { + clusteringKeyFieldID = field.FieldID + break + } + } 
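The vector branch of the pruner above (`FilterSegmentsByVector`) keeps segments in centroid-distance order until two conditions hold: at least `filterRatio * N` segments are kept and their accumulated row count covers `topk`; everything after that point is marked as filtered. A small standalone sketch of that selection heuristic, with made-up segment data (the real code also handles IP/COSINE ordering and multiple query vectors):

```go
package main

import (
	"fmt"
	"sort"
)

type segmentCandidate struct {
	segmentID int64
	distance  float32 // distance from the query vector to the segment centroid
	rows      int64
}

// selectSegments mirrors the selection step of FilterSegmentsByVector: sort by
// distance (ascending, as for L2), then keep segments until both the
// ratio-based minimum and the topk row budget are satisfied.
func selectSegments(cands []segmentCandidate, filterRatio float64, topk int64) []int64 {
	sort.SliceStable(cands, func(i, j int) bool { return cands[i].distance < cands[j].distance })

	targetSegNum := int(float64(len(cands)) * filterRatio)
	kept := make([]int64, 0, len(cands))
	var coveredRows int64
	for i, c := range cands {
		kept = append(kept, c.segmentID)
		coveredRows += c.rows
		if coveredRows >= topk && i >= targetSegNum {
			break
		}
	}
	return kept
}

func main() {
	cands := []segmentCandidate{
		{segmentID: 3, distance: 0.9, rows: 80},
		{segmentID: 1, distance: 0.1, rows: 80},
		{segmentID: 4, distance: 1.5, rows: 80},
		{segmentID: 2, distance: 0.4, rows: 80},
	}
	// With filterRatio=0.25 and topk=100, the two closest segments are enough.
	fmt.Println(selectSegments(cands, 0.25, 100)) // [1 2]
}
```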
+ centroids1 := []storage.VectorFieldValue{ + &storage.FloatVectorFieldValue{ + Value: []float32{0.6951474, 0.45225978, 0.51508516, 0.24968886, 0.6085484, 0.964968, 0.32239532, 0.7771577}, + }, + } + centroids2 := []storage.VectorFieldValue{ + &storage.FloatVectorFieldValue{ + Value: []float32{0.12345678, 0.23456789, 0.34567890, 0.45678901, 0.56789012, 0.67890123, 0.78901234, 0.89012345}, + }, + } + centroids3 := []storage.VectorFieldValue{ + &storage.FloatVectorFieldValue{ + Value: []float32{0.98765432, 0.87654321, 0.76543210, 0.65432109, 0.54321098, 0.43210987, 0.32109876, 0.21098765}, + }, + } + centroids4 := []storage.VectorFieldValue{ + &storage.FloatVectorFieldValue{ + Value: []float32{0.11111111, 0.22222222, 0.33333333, 0.44444444, 0.55555555, 0.66666666, 0.77777777, 0.88888888}, + }, + } + + // init partition stats + // here, for convenience, we set up both min/max and Centroids + // into the same struct, in the real user cases, a field stat + // can either contain min&&max or centroids + segStats := make(map[UniqueID]storage.SegmentStats) + switch clusterKeyFieldType { + case schemapb.DataType_Int64, schemapb.DataType_Int32, schemapb.DataType_Int16, schemapb.DataType_Int8: + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Int64, + Min: storage.NewInt64FieldValue(100), + Max: storage.NewInt64FieldValue(200), + Centroids: centroids1, + } + fieldStats = append(fieldStats, fieldStat1) + segStats[1] = *storage.NewSegmentStats(fieldStats, 80) + } + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Int64, + Min: storage.NewInt64FieldValue(100), + Max: storage.NewInt64FieldValue(400), + Centroids: centroids2, + } + fieldStats = append(fieldStats, fieldStat1) + segStats[2] = *storage.NewSegmentStats(fieldStats, 80) + } + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Int64, + Min: storage.NewInt64FieldValue(600), + Max: storage.NewInt64FieldValue(900), + Centroids: centroids3, + } + fieldStats = append(fieldStats, fieldStat1) + segStats[3] = *storage.NewSegmentStats(fieldStats, 80) + } + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Int64, + Min: storage.NewInt64FieldValue(500), + Max: storage.NewInt64FieldValue(1000), + Centroids: centroids4, + } + fieldStats = append(fieldStats, fieldStat1) + segStats[4] = *storage.NewSegmentStats(fieldStats, 80) + } + default: + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_VarChar, + Min: storage.NewStringFieldValue("ab"), + Max: storage.NewStringFieldValue("bbc"), + Centroids: centroids1, + } + fieldStats = append(fieldStats, fieldStat1) + segStats[1] = *storage.NewSegmentStats(fieldStats, 80) + } + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_VarChar, + Min: storage.NewStringFieldValue("hhh"), + Max: storage.NewStringFieldValue("jjx"), + Centroids: centroids2, + } + fieldStats = append(fieldStats, fieldStat1) + segStats[2] = *storage.NewSegmentStats(fieldStats, 80) + } + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, 
+ Type: schemapb.DataType_VarChar, + Min: storage.NewStringFieldValue("kkk"), + Max: storage.NewStringFieldValue("lmn"), + Centroids: centroids3, + } + fieldStats = append(fieldStats, fieldStat1) + segStats[3] = *storage.NewSegmentStats(fieldStats, 80) + } + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_VarChar, + Min: storage.NewStringFieldValue("oo2"), + Max: storage.NewStringFieldValue("pptt"), + Centroids: centroids4, + } + fieldStats = append(fieldStats, fieldStat1) + segStats[4] = *storage.NewSegmentStats(fieldStats, 80) + } + } + sps.partitionStats = make(map[UniqueID]*storage.PartitionStatsSnapshot) + sps.targetPartition = 11111 + sps.partitionStats[sps.targetPartition] = &storage.PartitionStatsSnapshot{ + SegmentStats: segStats, + } + + sealedSegments := make([]SnapshotItem, 0) + item1 := SnapshotItem{ + NodeID: 1, + Segments: []SegmentEntry{ + { + NodeID: 1, + SegmentID: 1, + }, + { + NodeID: 1, + SegmentID: 2, + }, + }, + } + item2 := SnapshotItem{ + NodeID: 2, + Segments: []SegmentEntry{ + { + NodeID: 2, + SegmentID: 3, + }, + { + NodeID: 2, + SegmentID: 4, + }, + }, + } + sealedSegments = append(sealedSegments, item1) + sealedSegments = append(sealedSegments, item2) + sps.sealedSegments = sealedSegments +} + +func (sps *SegmentPrunerSuite) TestPruneSegmentsByScalarIntField() { + sps.SetupForClustering("age", schemapb.DataType_Int32) + targetPartitions := make([]UniqueID, 0) + targetPartitions = append(targetPartitions, sps.targetPartition) + { + // test for exact values + testSegments := make([]SnapshotItem, len(sps.sealedSegments)) + copy(testSegments, sps.sealedSegments) + exprStr := "age==156" + schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema) + planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr) + sps.NoError(err) + serializedPlan, _ := proto.Marshal(planNode) + queryReq := &internalpb.RetrieveRequest{ + SerializedExprPlan: serializedPlan, + PartitionIDs: targetPartitions, + } + PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio}) + sps.Equal(2, len(testSegments[0].Segments)) + sps.Equal(0, len(testSegments[1].Segments)) + } + { + // test for range one expr part + testSegments := make([]SnapshotItem, len(sps.sealedSegments)) + copy(testSegments, sps.sealedSegments) + exprStr := "age>=700" + schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema) + planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr) + sps.NoError(err) + serializedPlan, _ := proto.Marshal(planNode) + queryReq := &internalpb.RetrieveRequest{ + SerializedExprPlan: serializedPlan, + PartitionIDs: targetPartitions, + } + PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio}) + sps.Equal(0, len(testSegments[0].Segments)) + sps.Equal(2, len(testSegments[1].Segments)) + } + { + // test for unlogical binary range + testSegments := make([]SnapshotItem, len(sps.sealedSegments)) + copy(testSegments, sps.sealedSegments) + exprStr := "age>=700 and age<=500" + schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema) + planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr) + sps.NoError(err) + serializedPlan, _ := proto.Marshal(planNode) + queryReq := &internalpb.RetrieveRequest{ + SerializedExprPlan: serializedPlan, + PartitionIDs: targetPartitions, + } + PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, 
sps.schema, testSegments, PruneInfo{defaultFilterRatio}) + sps.Equal(2, len(testSegments[0].Segments)) + sps.Equal(2, len(testSegments[1].Segments)) + } + { + // test for unlogical binary range + testSegments := make([]SnapshotItem, len(sps.sealedSegments)) + copy(testSegments, sps.sealedSegments) + exprStr := "age>=500 and age<=550" + schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema) + planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr) + sps.NoError(err) + serializedPlan, _ := proto.Marshal(planNode) + queryReq := &internalpb.RetrieveRequest{ + SerializedExprPlan: serializedPlan, + PartitionIDs: targetPartitions, + } + PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio}) + sps.Equal(0, len(testSegments[0].Segments)) + sps.Equal(1, len(testSegments[1].Segments)) + } +} + +func (sps *SegmentPrunerSuite) TestPruneSegmentsByScalarStrField() { + sps.SetupForClustering("info", schemapb.DataType_VarChar) + targetPartitions := make([]UniqueID, 0) + targetPartitions = append(targetPartitions, sps.targetPartition) + { + // test for exact str values + testSegments := make([]SnapshotItem, len(sps.sealedSegments)) + copy(testSegments, sps.sealedSegments) + exprStr := `info=="rag"` + schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema) + planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr) + sps.NoError(err) + serializedPlan, _ := proto.Marshal(planNode) + queryReq := &internalpb.RetrieveRequest{ + SerializedExprPlan: serializedPlan, + PartitionIDs: targetPartitions, + } + PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio}) + sps.Equal(0, len(testSegments[0].Segments)) + sps.Equal(0, len(testSegments[1].Segments)) + // there should be no segments fulfilling the info=="rag" + } + { + // test for exact str values + testSegments := make([]SnapshotItem, len(sps.sealedSegments)) + copy(testSegments, sps.sealedSegments) + exprStr := `info=="kpl"` + schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema) + planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr) + sps.NoError(err) + serializedPlan, _ := proto.Marshal(planNode) + queryReq := &internalpb.RetrieveRequest{ + SerializedExprPlan: serializedPlan, + PartitionIDs: targetPartitions, + } + PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio}) + sps.Equal(0, len(testSegments[0].Segments)) + sps.Equal(1, len(testSegments[1].Segments)) + // there should be no segments fulfilling the info=="rag" + } + { + // test for unary str values + testSegments := make([]SnapshotItem, len(sps.sealedSegments)) + copy(testSegments, sps.sealedSegments) + exprStr := `info<="less"` + schemaHelper, _ := typeutil.CreateSchemaHelper(sps.schema) + planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr) + sps.NoError(err) + serializedPlan, _ := proto.Marshal(planNode) + queryReq := &internalpb.RetrieveRequest{ + SerializedExprPlan: serializedPlan, + PartitionIDs: targetPartitions, + } + PruneSegments(context.TODO(), sps.partitionStats, nil, queryReq, sps.schema, testSegments, PruneInfo{defaultFilterRatio}) + sps.Equal(2, len(testSegments[0].Segments)) + sps.Equal(1, len(testSegments[1].Segments)) + // there should be no segments fulfilling the info=="rag" + } +} + +func vector2Placeholder(vectors [][]float32) *commonpb.PlaceholderValue { + ph := &commonpb.PlaceholderValue{ + Tag: "$0", + 
Values: make([][]byte, 0, len(vectors)), + } + if len(vectors) == 0 { + return ph + } + + ph.Type = commonpb.PlaceholderType_FloatVector + for _, vector := range vectors { + ph.Values = append(ph.Values, clustering.SerializeFloatVector(vector)) + } + return ph +} + +func (sps *SegmentPrunerSuite) TestPruneSegmentsByVectorField() { + sps.SetupForClustering("vec", schemapb.DataType_FloatVector) + + vector1 := []float32{0.8877872002188053, 0.6131822285635065, 0.8476814632326242, 0.6645877829359371, 0.9962627712600025, 0.8976183052440327, 0.41941169325798844, 0.7554387854258499} + vector2 := []float32{0.8644394874390322, 0.023327886647378615, 0.08330118483461302, 0.7068040179963112, 0.6983994910799851, 0.5562075958994153, 0.3288536247938002, 0.07077341010237759} + vectors := [][]float32{vector1, vector2} + + phg := &commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{ + vector2Placeholder(vectors), + }, + } + bs, _ := proto.Marshal(phg) + // test for L2 metrics + req := &internalpb.SearchRequest{ + MetricType: "L2", + PlaceholderGroup: bs, + PartitionIDs: []UniqueID{sps.targetPartition}, + Topk: 100, + } + + PruneSegments(context.TODO(), sps.partitionStats, req, nil, sps.schema, sps.sealedSegments, PruneInfo{0.25}) + sps.Equal(1, len(sps.sealedSegments[0].Segments)) + sps.Equal(int64(1), sps.sealedSegments[0].Segments[0].SegmentID) + sps.Equal(1, len(sps.sealedSegments[1].Segments)) + sps.Equal(int64(3), sps.sealedSegments[1].Segments[0].SegmentID) + + // test for IP metrics + req = &internalpb.SearchRequest{ + MetricType: "IP", + PlaceholderGroup: bs, + PartitionIDs: []UniqueID{sps.targetPartition}, + Topk: 100, + } + + PruneSegments(context.TODO(), sps.partitionStats, req, nil, sps.schema, sps.sealedSegments, PruneInfo{0.25}) + sps.Equal(1, len(sps.sealedSegments[0].Segments)) + sps.Equal(int64(1), sps.sealedSegments[0].Segments[0].SegmentID) + sps.Equal(1, len(sps.sealedSegments[1].Segments)) + sps.Equal(int64(3), sps.sealedSegments[1].Segments[0].SegmentID) +} + +func TestSegmentPrunerSuite(t *testing.T) { + suite.Run(t, new(SegmentPrunerSuite)) +} diff --git a/internal/querynodev2/segments/reduce_test.go b/internal/querynodev2/segments/reduce_test.go index a310126b9c..625ec3e4a8 100644 --- a/internal/querynodev2/segments/reduce_test.go +++ b/internal/querynodev2/segments/reduce_test.go @@ -60,7 +60,7 @@ func (suite *ReduceSuite) SetupTest() { msgLength := 100 suite.rootPath = suite.T().Name() - chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath) + chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath) suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx) initcore.InitRemoteChunkManager(paramtable.Get()) diff --git a/internal/querynodev2/segments/retrieve_test.go b/internal/querynodev2/segments/retrieve_test.go index 02a0eae1bf..0b549a63e0 100644 --- a/internal/querynodev2/segments/retrieve_test.go +++ b/internal/querynodev2/segments/retrieve_test.go @@ -61,7 +61,7 @@ func (suite *RetrieveSuite) SetupTest() { msgLength := 100 suite.rootPath = suite.T().Name() - chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath) + chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath) suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx) initcore.InitRemoteChunkManager(paramtable.Get()) diff --git a/internal/querynodev2/segments/segment_loader_test.go 
b/internal/querynodev2/segments/segment_loader_test.go index 3ec906ad49..e7e335a5c5 100644 --- a/internal/querynodev2/segments/segment_loader_test.go +++ b/internal/querynodev2/segments/segment_loader_test.go @@ -78,7 +78,7 @@ func (suite *SegmentLoaderSuite) SetupTest() { // TODO:: cpp chunk manager not support local chunk manager // suite.chunkManager = storage.NewLocalChunkManager(storage.RootPath( // fmt.Sprintf("/tmp/milvus-ut/%d", rand.Int63()))) - chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath) + chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath) suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx) suite.loader = NewLoader(suite.manager, suite.chunkManager) initcore.InitRemoteChunkManager(paramtable.Get()) @@ -678,7 +678,7 @@ func (suite *SegmentLoaderDetailSuite) SetupTest() { } ctx := context.Background() - chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath) + chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath) suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx) suite.loader = NewLoader(suite.manager, suite.chunkManager) initcore.InitRemoteChunkManager(paramtable.Get()) @@ -847,7 +847,7 @@ func (suite *SegmentLoaderV2Suite) SetupTest() { // TODO:: cpp chunk manager not support local chunk manager // suite.chunkManager = storage.NewLocalChunkManager(storage.RootPath( // fmt.Sprintf("/tmp/milvus-ut/%d", rand.Int63()))) - chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath) + chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath) suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx) suite.loader = NewLoaderV2(suite.manager, suite.chunkManager) initcore.InitRemoteChunkManager(paramtable.Get()) diff --git a/internal/querynodev2/segments/segment_test.go b/internal/querynodev2/segments/segment_test.go index c03a902857..31271388af 100644 --- a/internal/querynodev2/segments/segment_test.go +++ b/internal/querynodev2/segments/segment_test.go @@ -39,7 +39,7 @@ func (suite *SegmentSuite) SetupTest() { msgLength := 100 suite.rootPath = suite.T().Name() - chunkManagerFactory := NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath) + chunkManagerFactory := storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath) suite.chunkManager, _ = chunkManagerFactory.NewPersistentStorageChunkManager(ctx) initcore.InitRemoteChunkManager(paramtable.Get()) diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go index d803a6ce4e..7755911637 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -254,6 +254,7 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm node.factory, channel.GetSeekPosition().GetTimestamp(), node.queryHook, + node.chunkManager, ) if err != nil { log.Warn("failed to create shard delegator", zap.Error(err)) diff --git a/internal/querynodev2/services_test.go b/internal/querynodev2/services_test.go index 8e48598123..a081ca4783 100644 --- a/internal/querynodev2/services_test.go +++ b/internal/querynodev2/services_test.go @@ -116,7 +116,7 @@ func (suite *ServiceSuite) SetupTest() { suite.msgStream = msgstream.NewMockMsgStream(suite.T()) // TODO:: cpp chunk manager not support local chunk manager // suite.chunkManagerFactory = storage.NewChunkManagerFactory("local", 
storage.RootPath("/tmp/milvus-test")) - suite.chunkManagerFactory = segments.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath) + suite.chunkManagerFactory = storage.NewTestChunkManagerFactory(paramtable.Get(), suite.rootPath) suite.factory.EXPECT().Init(mock.Anything).Return() suite.factory.EXPECT().NewPersistentStorageChunkManager(mock.Anything).Return(suite.chunkManagerFactory.NewPersistentStorageChunkManager(ctx)) diff --git a/internal/storage/partition_stats.go b/internal/storage/partition_stats.go index 6f55675e1d..15173e4457 100644 --- a/internal/storage/partition_stats.go +++ b/internal/storage/partition_stats.go @@ -20,6 +20,14 @@ import "encoding/json" type SegmentStats struct { FieldStats []FieldStats `json:"fieldStats"` + NumRows int +} + +func NewSegmentStats(fieldStats []FieldStats, rows int) *SegmentStats { + return &SegmentStats{ + FieldStats: fieldStats, + NumRows: rows, + } } type PartitionStatsSnapshot struct { diff --git a/internal/storage/utils.go b/internal/storage/utils.go index 9c91b9cbcc..a5cde25050 100644 --- a/internal/storage/utils.go +++ b/internal/storage/utils.go @@ -38,6 +38,7 @@ import ( "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/mq/msgstream" "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" "github.com/milvus-io/milvus/pkg/util/typeutil" ) @@ -1247,3 +1248,17 @@ func Min(a, b int64) int64 { } return b } + +func NewTestChunkManagerFactory(params *paramtable.ComponentParam, rootPath string) *ChunkManagerFactory { + return NewChunkManagerFactory("minio", + RootPath(rootPath), + Address(params.MinioCfg.Address.GetValue()), + AccessKeyID(params.MinioCfg.AccessKeyID.GetValue()), + SecretAccessKeyID(params.MinioCfg.SecretAccessKey.GetValue()), + UseSSL(params.MinioCfg.UseSSL.GetAsBool()), + BucketName(params.MinioCfg.BucketName.GetValue()), + UseIAM(params.MinioCfg.UseIAM.GetAsBool()), + CloudProvider(params.MinioCfg.CloudProvider.GetValue()), + IAMEndpoint(params.MinioCfg.IAMEndpoint.GetValue()), + CreateBucket(true)) +} diff --git a/internal/util/clustering/clustering.go b/internal/util/clustering/clustering.go new file mode 100644 index 0000000000..c8b290f185 --- /dev/null +++ b/internal/util/clustering/clustering.go @@ -0,0 +1,50 @@ +package clustering + +import ( + "encoding/binary" + "math" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/util/distance" + "github.com/milvus-io/milvus/pkg/util/merr" +) + +func CalcVectorDistance(dim int64, dataType schemapb.DataType, left []byte, right []float32, metric string) ([]float32, error) { + switch dataType { + case schemapb.DataType_FloatVector: + distance, err := distance.CalcFloatDistance(dim, DeserializeFloatVector(left), right, metric) + if err != nil { + return nil, err + } + return distance, nil + // todo support other vector type + case schemapb.DataType_BinaryVector: + case schemapb.DataType_Float16Vector: + case schemapb.DataType_BFloat16Vector: + default: + return nil, merr.ErrParameterInvalid + } + return nil, nil +} + +func DeserializeFloatVector(data []byte) []float32 { + vectorLen := len(data) / 4 // Each float32 occupies 4 bytes + fv := make([]float32, vectorLen) + + for i := 0; i < vectorLen; i++ { + bits := binary.LittleEndian.Uint32(data[i*4 : (i+1)*4]) + fv[i] = math.Float32frombits(bits) + } + + return fv +} + +func SerializeFloatVector(fv []float32) []byte { + data := make([]byte, 0, 4*len(fv)) // float32 occupies 4 bytes + buf := make([]byte, 4) + for _, f := range 
fv { + binary.LittleEndian.PutUint32(buf, math.Float32bits(f)) + data = append(data, buf...) + } + return data +} diff --git a/internal/util/exprutil/expr_checker.go b/internal/util/exprutil/expr_checker.go new file mode 100644 index 0000000000..00866a9aa9 --- /dev/null +++ b/internal/util/exprutil/expr_checker.go @@ -0,0 +1,511 @@ +package exprutil + +import ( + "math" + "strings" + + "github.com/cockroachdb/errors" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/proto/planpb" +) + +type KeyType int64 + +const ( + PartitionKey KeyType = iota + ClusteringKey KeyType = PartitionKey + 1 +) + +func ParseExprFromPlan(plan *planpb.PlanNode) (*planpb.Expr, error) { + node := plan.GetNode() + + if node == nil { + return nil, errors.New("can't get expr from empty plan node") + } + + var expr *planpb.Expr + switch node := node.(type) { + case *planpb.PlanNode_VectorAnns: + expr = node.VectorAnns.GetPredicates() + case *planpb.PlanNode_Query: + expr = node.Query.GetPredicates() + default: + return nil, errors.New("unsupported plan node type") + } + + return expr, nil +} + +func ParsePartitionKeysFromBinaryExpr(expr *planpb.BinaryExpr, keyType KeyType) ([]*planpb.GenericValue, bool) { + leftRes, leftInRange := ParseKeysFromExpr(expr.Left, keyType) + rightRes, rightInRange := ParseKeysFromExpr(expr.Right, keyType) + + if expr.Op == planpb.BinaryExpr_LogicalAnd { + // case: partition_key_field in [7, 8] && partition_key > 8 + if len(leftRes)+len(rightRes) > 0 { + leftRes = append(leftRes, rightRes...) + return leftRes, false + } + + // case: other_field > 10 && partition_key_field > 8 + return nil, leftInRange || rightInRange + } + + if expr.Op == planpb.BinaryExpr_LogicalOr { + // case: partition_key_field in [7, 8] or partition_key > 8 + if leftInRange || rightInRange { + return nil, true + } + + // case: partition_key_field in [7, 8] or other_field > 10 + leftRes = append(leftRes, rightRes...) 
+ return leftRes, false + } + + return nil, false +} + +func ParsePartitionKeysFromUnaryExpr(expr *planpb.UnaryExpr, keyType KeyType) ([]*planpb.GenericValue, bool) { + res, partitionInRange := ParseKeysFromExpr(expr.GetChild(), keyType) + if expr.Op == planpb.UnaryExpr_Not { + // case: partition_key_field not in [7, 8] + if len(res) != 0 { + return nil, true + } + + // case: other_field not in [10] + return nil, partitionInRange + } + + // UnaryOp only includes "Not" for now + return res, partitionInRange +} + +func ParsePartitionKeysFromTermExpr(expr *planpb.TermExpr, keyType KeyType) ([]*planpb.GenericValue, bool) { + if keyType == PartitionKey && expr.GetColumnInfo().GetIsPartitionKey() { + return expr.GetValues(), false + } else if keyType == ClusteringKey && expr.GetColumnInfo().GetIsClusteringKey() { + return expr.GetValues(), false + } + return nil, false +} + +func ParsePartitionKeysFromUnaryRangeExpr(expr *planpb.UnaryRangeExpr, keyType KeyType) ([]*planpb.GenericValue, bool) { + if expr.GetOp() == planpb.OpType_Equal { + if expr.GetColumnInfo().GetIsPartitionKey() && keyType == PartitionKey || + expr.GetColumnInfo().GetIsClusteringKey() && keyType == ClusteringKey { + return []*planpb.GenericValue{expr.Value}, false + } + } + return nil, true +} + +func ParseKeysFromExpr(expr *planpb.Expr, keyType KeyType) ([]*planpb.GenericValue, bool) { + var res []*planpb.GenericValue + keyInRange := false + switch expr := expr.GetExpr().(type) { + case *planpb.Expr_BinaryExpr: + res, keyInRange = ParsePartitionKeysFromBinaryExpr(expr.BinaryExpr, keyType) + case *planpb.Expr_UnaryExpr: + res, keyInRange = ParsePartitionKeysFromUnaryExpr(expr.UnaryExpr, keyType) + case *planpb.Expr_TermExpr: + res, keyInRange = ParsePartitionKeysFromTermExpr(expr.TermExpr, keyType) + case *planpb.Expr_UnaryRangeExpr: + res, keyInRange = ParsePartitionKeysFromUnaryRangeExpr(expr.UnaryRangeExpr, keyType) + } + + return res, keyInRange +} + +func ParseKeys(expr *planpb.Expr, kType KeyType) []*planpb.GenericValue { + res, keyInRange := ParseKeysFromExpr(expr, kType) + if keyInRange { + res = nil + } + + return res +} + +type PlanRange struct { + lower *planpb.GenericValue + upper *planpb.GenericValue + includeLower bool + includeUpper bool +} + +func (planRange *PlanRange) ToIntRange() *IntRange { + iRange := &IntRange{} + if planRange.lower == nil { + iRange.lower = math.MinInt64 + iRange.includeLower = false + } else { + iRange.lower = planRange.lower.GetInt64Val() + iRange.includeLower = planRange.includeLower + } + + if planRange.upper == nil { + iRange.upper = math.MaxInt64 + iRange.includeUpper = false + } else { + iRange.upper = planRange.upper.GetInt64Val() + iRange.includeUpper = planRange.includeUpper + } + return iRange +} + +func (planRange *PlanRange) ToStrRange() *StrRange { + sRange := &StrRange{} + if planRange.lower == nil { + sRange.lower = "" + sRange.includeLower = false + } else { + sRange.lower = planRange.lower.GetStringVal() + sRange.includeLower = planRange.includeLower + } + + if planRange.upper == nil { + sRange.upper = "" + sRange.includeUpper = false + } else { + sRange.upper = planRange.upper.GetStringVal() + sRange.includeUpper = planRange.includeUpper + } + return sRange +} + +type IntRange struct { + lower int64 + upper int64 + includeLower bool + includeUpper bool +} + +func NewIntRange(l int64, r int64, includeL bool, includeR bool) *IntRange { + return &IntRange{ + lower: l, + upper: r, + includeLower: includeL, + includeUpper: includeR, + } +} + +func IntRangeOverlap(range1 
*IntRange, range2 *IntRange) bool {
+	var leftBound int64
+	if range1.lower < range2.lower {
+		leftBound = range2.lower
+	} else {
+		leftBound = range1.lower
+	}
+	var rightBound int64
+	if range1.upper < range2.upper {
+		rightBound = range1.upper
+	} else {
+		rightBound = range2.upper
+	}
+	return leftBound <= rightBound
+}
+
+type StrRange struct {
+	lower        string
+	upper        string
+	includeLower bool
+	includeUpper bool
+}
+
+func NewStrRange(l string, r string, includeL bool, includeR bool) *StrRange {
+	return &StrRange{
+		lower:        l,
+		upper:        r,
+		includeLower: includeL,
+		includeUpper: includeR,
+	}
+}
+
+func StrRangeOverlap(range1 *StrRange, range2 *StrRange) bool {
+	var leftBound string
+	if range1.lower < range2.lower {
+		leftBound = range2.lower
+	} else {
+		leftBound = range1.lower
+	}
+	var rightBound string
+	if range1.upper < range2.upper || range2.upper == "" {
+		rightBound = range1.upper
+	} else {
+		rightBound = range2.upper
+	}
+	return leftBound <= rightBound
+}
+
+/*
+principles for range parsing
+1. no handling unary expr like 'NOT'
+2. no handling 'or' expr, no matter on clusteringKey or not, just terminate all possible prune
+3. for any unlogical 'and' expr, we check and terminate right away
+4. no handling Term and Range at the same time
+*/
+
+func ParseRanges(expr *planpb.Expr, kType KeyType) ([]*PlanRange, bool) {
+	var res []*PlanRange
+	matchALL := true
+	switch expr := expr.GetExpr().(type) {
+	case *planpb.Expr_BinaryExpr:
+		res, matchALL = ParseRangesFromBinaryExpr(expr.BinaryExpr, kType)
+	case *planpb.Expr_UnaryRangeExpr:
+		res, matchALL = ParseRangesFromUnaryRangeExpr(expr.UnaryRangeExpr, kType)
+	case *planpb.Expr_TermExpr:
+		res, matchALL = ParseRangesFromTermExpr(expr.TermExpr, kType)
+	case *planpb.Expr_UnaryExpr:
+		res, matchALL = nil, true
+		// we don't handle NOT operation, just consider as unable_to_parse_range
+	}
+	return res, matchALL
+}
+
+func ParseRangesFromBinaryExpr(expr *planpb.BinaryExpr, kType KeyType) ([]*PlanRange, bool) {
+	if expr.Op == planpb.BinaryExpr_LogicalOr {
+		return nil, true
+	}
+	_, leftIsTerm := expr.GetLeft().GetExpr().(*planpb.Expr_TermExpr)
+	_, rightIsTerm := expr.GetRight().GetExpr().(*planpb.Expr_TermExpr)
+	if leftIsTerm || rightIsTerm {
+		// either of left or right is a term query like x IN [1,2,3]
+		// we will terminate the prune process
+		return nil, true
+	}
+	leftRanges, leftALL := ParseRanges(expr.Left, kType)
+	rightRanges, rightALL := ParseRanges(expr.Right, kType)
+	if leftALL && rightALL {
+		return nil, true
+	} else if leftALL && !rightALL {
+		return rightRanges, rightALL
+	} else if rightALL && !leftALL {
+		return leftRanges, leftALL
+	}
+	// only unary ranges or further binary ranges are left
+	// calculate the intersection and return the resulting ranges
+	// it's expected that only a single range can be returned from the left and right child
+	if len(leftRanges) != 1 || len(rightRanges) != 1 {
+		return nil, true
+	}
+	intersected := Intersect(leftRanges[0], rightRanges[0])
+	matchALL := intersected == nil
+	return []*PlanRange{intersected}, matchALL
+}
+
+func ParseRangesFromUnaryRangeExpr(expr *planpb.UnaryRangeExpr, kType KeyType) ([]*PlanRange, bool) {
+	if expr.GetColumnInfo().GetIsPartitionKey() && kType == PartitionKey ||
+		expr.GetColumnInfo().GetIsClusteringKey() && kType == ClusteringKey {
+		switch expr.GetOp() {
+		case planpb.OpType_Equal:
+			{
+				return []*PlanRange{
+					{
+						lower:        expr.Value,
+						upper:        expr.Value,
+						includeLower: true,
+						includeUpper: true,
+					},
+				}, false
+			}
+		case 
planpb.OpType_GreaterThan: + { + return []*PlanRange{ + { + lower: expr.Value, + upper: nil, + includeLower: false, + includeUpper: false, + }, + }, false + } + case planpb.OpType_GreaterEqual: + { + return []*PlanRange{ + { + lower: expr.Value, + upper: nil, + includeLower: true, + includeUpper: false, + }, + }, false + } + case planpb.OpType_LessThan: + { + return []*PlanRange{ + { + lower: nil, + upper: expr.Value, + includeLower: false, + includeUpper: false, + }, + }, false + } + case planpb.OpType_LessEqual: + { + return []*PlanRange{ + { + lower: nil, + upper: expr.Value, + includeLower: false, + includeUpper: true, + }, + }, false + } + } + } + return nil, true +} + +func ParseRangesFromTermExpr(expr *planpb.TermExpr, kType KeyType) ([]*PlanRange, bool) { + if expr.GetColumnInfo().GetIsPartitionKey() && kType == PartitionKey || + expr.GetColumnInfo().GetIsClusteringKey() && kType == ClusteringKey { + res := make([]*PlanRange, 0) + for _, value := range expr.GetValues() { + res = append(res, &PlanRange{ + lower: value, + upper: value, + includeLower: true, + includeUpper: true, + }) + } + return res, false + } + return nil, true +} + +var minusInfiniteInt = &planpb.GenericValue{ + Val: &planpb.GenericValue_Int64Val{ + Int64Val: math.MinInt64, + }, +} + +var positiveInfiniteInt = &planpb.GenericValue{ + Val: &planpb.GenericValue_Int64Val{ + Int64Val: math.MaxInt64, + }, +} + +var minStrVal = &planpb.GenericValue{ + Val: &planpb.GenericValue_StringVal{ + StringVal: "", + }, +} + +var maxStrVal = &planpb.GenericValue{} + +func complementPlanRange(pr *PlanRange, dataType schemapb.DataType) *PlanRange { + if dataType == schemapb.DataType_Int64 { + if pr.lower == nil { + pr.lower = minusInfiniteInt + } + if pr.upper == nil { + pr.upper = positiveInfiniteInt + } + } else { + if pr.lower == nil { + pr.lower = minStrVal + } + if pr.upper == nil { + pr.upper = maxStrVal + } + } + + return pr +} + +func GetCommonDataType(a *PlanRange, b *PlanRange) schemapb.DataType { + var bound *planpb.GenericValue + if a.lower != nil { + bound = a.lower + } else if a.upper != nil { + bound = a.upper + } + if bound == nil { + if b.lower != nil { + bound = b.lower + } else if b.upper != nil { + bound = b.upper + } + } + if bound == nil { + return schemapb.DataType_None + } + switch bound.Val.(type) { + case *planpb.GenericValue_Int64Val: + { + return schemapb.DataType_Int64 + } + case *planpb.GenericValue_StringVal: + { + return schemapb.DataType_VarChar + } + } + return schemapb.DataType_None +} + +func Intersect(a *PlanRange, b *PlanRange) *PlanRange { + dataType := GetCommonDataType(a, b) + complementPlanRange(a, dataType) + complementPlanRange(b, dataType) + + // Check if 'a' and 'b' non-overlapping at all + rightBound := minGenericValue(a.upper, b.upper) + leftBound := maxGenericValue(a.lower, b.lower) + if compareGenericValue(leftBound, rightBound) > 0 { + return nil + } + + // Check if 'a' range ends exactly where 'b' range starts + if !a.includeUpper && !b.includeLower && (compareGenericValue(a.upper, b.lower) == 0) { + return nil + } + // Check if 'b' range ends exactly where 'a' range starts + if !b.includeUpper && !a.includeLower && (compareGenericValue(b.upper, a.lower) == 0) { + return nil + } + + return &PlanRange{ + lower: leftBound, + upper: rightBound, + includeLower: a.includeLower || b.includeLower, + includeUpper: a.includeUpper || b.includeUpper, + } +} + +func compareGenericValue(left *planpb.GenericValue, right *planpb.GenericValue) int64 { + if right == nil || left == nil { + return -1 
+ } + switch left.Val.(type) { + case *planpb.GenericValue_Int64Val: + if left.GetInt64Val() == right.GetInt64Val() { + return 0 + } else if left.GetInt64Val() < right.GetInt64Val() { + return -1 + } else { + return 1 + } + case *planpb.GenericValue_StringVal: + if right.Val == nil { + return -1 + } + return int64(strings.Compare(left.GetStringVal(), right.GetStringVal())) + } + return 0 +} + +func minGenericValue(left *planpb.GenericValue, right *planpb.GenericValue) *planpb.GenericValue { + if compareGenericValue(left, right) < 0 { + return left + } + return right +} + +func maxGenericValue(left *planpb.GenericValue, right *planpb.GenericValue) *planpb.GenericValue { + if compareGenericValue(left, right) >= 0 { + return left + } + return right +} diff --git a/internal/util/exprutil/expr_checker_test.go b/internal/util/exprutil/expr_checker_test.go new file mode 100644 index 0000000000..dc519331eb --- /dev/null +++ b/internal/util/exprutil/expr_checker_test.go @@ -0,0 +1,279 @@ +package exprutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/parser/planparserv2" + "github.com/milvus-io/milvus/internal/proto/planpb" + "github.com/milvus-io/milvus/internal/util/testutil" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/util/funcutil" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +func TestParsePartitionKeys(t *testing.T) { + prefix := "TestParsePartitionKeys" + collectionName := prefix + funcutil.GenRandomStr() + + fieldName2Type := make(map[string]schemapb.DataType) + fieldName2Type["int64_field"] = schemapb.DataType_Int64 + fieldName2Type["varChar_field"] = schemapb.DataType_VarChar + fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector + schema := testutil.ConstructCollectionSchemaByDataType(collectionName, fieldName2Type, + "int64_field", false, 8) + partitionKeyField := &schemapb.FieldSchema{ + Name: "partition_key_field", + DataType: schemapb.DataType_Int64, + IsPartitionKey: true, + } + schema.Fields = append(schema.Fields, partitionKeyField) + + fieldID := common.StartOfUserFieldID + for _, field := range schema.Fields { + field.FieldID = int64(fieldID) + fieldID++ + } + schemaHelper, err := typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) + + queryInfo := &planpb.QueryInfo{ + Topk: 10, + MetricType: "L2", + SearchParams: "", + RoundDecimal: -1, + } + + type testCase struct { + name string + expr string + expected int + validPartitionKeys []int64 + invalidPartitionKeys []int64 + } + cases := []testCase{ + { + name: "binary_expr_and with term", + expr: "partition_key_field in [7, 8] && int64_field >= 10", + expected: 2, + validPartitionKeys: []int64{7, 8}, + invalidPartitionKeys: []int64{}, + }, + { + name: "binary_expr_and with equal", + expr: "partition_key_field == 7 && int64_field >= 10", + expected: 1, + validPartitionKeys: []int64{7}, + invalidPartitionKeys: []int64{}, + }, + { + name: "binary_expr_and with term2", + expr: "partition_key_field in [7, 8] && int64_field == 10", + expected: 2, + validPartitionKeys: []int64{7, 8}, + invalidPartitionKeys: []int64{10}, + }, + { + name: "binary_expr_and with partition key in range", + expr: "partition_key_field in [7, 8] && partition_key_field > 9", + expected: 2, + validPartitionKeys: []int64{7, 8}, + invalidPartitionKeys: []int64{9}, + }, + { + name: "binary_expr_and with partition key in range2", + expr: "int64_field == 
10 && partition_key_field > 9", + expected: 0, + validPartitionKeys: []int64{}, + invalidPartitionKeys: []int64{}, + }, + { + name: "binary_expr_and with term and not", + expr: "partition_key_field in [7, 8] && partition_key_field not in [10, 20]", + expected: 2, + validPartitionKeys: []int64{7, 8}, + invalidPartitionKeys: []int64{10, 20}, + }, + { + name: "binary_expr_or with term and not", + expr: "partition_key_field in [7, 8] or partition_key_field not in [10, 20]", + expected: 0, + validPartitionKeys: []int64{}, + invalidPartitionKeys: []int64{}, + }, + { + name: "binary_expr_or with term and not 2", + expr: "partition_key_field in [7, 8] or int64_field not in [10, 20]", + expected: 2, + validPartitionKeys: []int64{7, 8}, + invalidPartitionKeys: []int64{10, 20}, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // test search plan + searchPlan, err := planparserv2.CreateSearchPlan(schemaHelper, tc.expr, "fvec_field", queryInfo) + assert.NoError(t, err) + expr, err := ParseExprFromPlan(searchPlan) + assert.NoError(t, err) + partitionKeys := ParseKeys(expr, PartitionKey) + assert.Equal(t, tc.expected, len(partitionKeys)) + for _, key := range partitionKeys { + int64Val := key.Val.(*planpb.GenericValue_Int64Val).Int64Val + assert.Contains(t, tc.validPartitionKeys, int64Val) + assert.NotContains(t, tc.invalidPartitionKeys, int64Val) + } + + // test query plan + queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, tc.expr) + assert.NoError(t, err) + expr, err = ParseExprFromPlan(queryPlan) + assert.NoError(t, err) + partitionKeys = ParseKeys(expr, PartitionKey) + assert.Equal(t, tc.expected, len(partitionKeys)) + for _, key := range partitionKeys { + int64Val := key.Val.(*planpb.GenericValue_Int64Val).Int64Val + assert.Contains(t, tc.validPartitionKeys, int64Val) + assert.NotContains(t, tc.invalidPartitionKeys, int64Val) + } + }) + } +} + +func TestParseIntRanges(t *testing.T) { + prefix := "TestParseRanges" + clusterKeyField := "cluster_key_field" + collectionName := prefix + funcutil.GenRandomStr() + + fieldName2Type := make(map[string]schemapb.DataType) + fieldName2Type["int64_field"] = schemapb.DataType_Int64 + fieldName2Type["varChar_field"] = schemapb.DataType_VarChar + fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector + schema := testutil.ConstructCollectionSchemaByDataType(collectionName, fieldName2Type, + "int64_field", false, 8) + clusterKeyFieldSchema := &schemapb.FieldSchema{ + Name: clusterKeyField, + DataType: schemapb.DataType_Int64, + IsClusteringKey: true, + } + schema.Fields = append(schema.Fields, clusterKeyFieldSchema) + + fieldID := common.StartOfUserFieldID + for _, field := range schema.Fields { + field.FieldID = int64(fieldID) + fieldID++ + } + schemaHelper, err := typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) + // test query plan + { + expr := "cluster_key_field > 50" + queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr) + assert.NoError(t, err) + planExpr, err := ParseExprFromPlan(queryPlan) + assert.NoError(t, err) + parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey) + assert.False(t, matchALL) + assert.Equal(t, 1, len(parsedRanges)) + range0 := parsedRanges[0] + assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_Int64Val).Int64Val, int64(50)) + assert.Nil(t, range0.upper) + assert.Equal(t, range0.includeLower, false) + assert.Equal(t, range0.includeUpper, false) + } + + // test binary query plan + { + expr := "cluster_key_field > 50 and cluster_key_field <= 100" 
+ queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr) + assert.NoError(t, err) + planExpr, err := ParseExprFromPlan(queryPlan) + assert.NoError(t, err) + parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey) + assert.False(t, matchALL) + assert.Equal(t, 1, len(parsedRanges)) + range0 := parsedRanges[0] + assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_Int64Val).Int64Val, int64(50)) + assert.Equal(t, false, range0.includeLower) + assert.Equal(t, true, range0.includeUpper) + } + + // test binary query plan + { + expr := "cluster_key_field >= 50 and cluster_key_field < 100" + queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr) + assert.NoError(t, err) + planExpr, err := ParseExprFromPlan(queryPlan) + assert.NoError(t, err) + parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey) + assert.False(t, matchALL) + assert.Equal(t, 1, len(parsedRanges)) + range0 := parsedRanges[0] + assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_Int64Val).Int64Val, int64(50)) + assert.Equal(t, true, range0.includeLower) + assert.Equal(t, false, range0.includeUpper) + } + + // test binary query plan + { + expr := "cluster_key_field in [100]" + queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr) + assert.NoError(t, err) + planExpr, err := ParseExprFromPlan(queryPlan) + assert.NoError(t, err) + parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey) + assert.False(t, matchALL) + assert.Equal(t, 1, len(parsedRanges)) + range0 := parsedRanges[0] + assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_Int64Val).Int64Val, int64(100)) + assert.Equal(t, true, range0.includeLower) + assert.Equal(t, true, range0.includeUpper) + } +} + +func TestParseStrRanges(t *testing.T) { + prefix := "TestParseRanges" + clusterKeyField := "cluster_key_field" + collectionName := prefix + funcutil.GenRandomStr() + + fieldName2Type := make(map[string]schemapb.DataType) + fieldName2Type["int64_field"] = schemapb.DataType_Int64 + fieldName2Type["varChar_field"] = schemapb.DataType_VarChar + fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector + schema := testutil.ConstructCollectionSchemaByDataType(collectionName, fieldName2Type, + "int64_field", false, 8) + clusterKeyFieldSchema := &schemapb.FieldSchema{ + Name: clusterKeyField, + DataType: schemapb.DataType_VarChar, + IsClusteringKey: true, + } + schema.Fields = append(schema.Fields, clusterKeyFieldSchema) + + fieldID := common.StartOfUserFieldID + for _, field := range schema.Fields { + field.FieldID = int64(fieldID) + fieldID++ + } + schemaHelper, err := typeutil.CreateSchemaHelper(schema) + require.NoError(t, err) + // test query plan + { + expr := "cluster_key_field >= \"aaa\"" + queryPlan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr) + assert.NoError(t, err) + planExpr, err := ParseExprFromPlan(queryPlan) + assert.NoError(t, err) + parsedRanges, matchALL := ParseRanges(planExpr, ClusteringKey) + assert.False(t, matchALL) + assert.Equal(t, 1, len(parsedRanges)) + range0 := parsedRanges[0] + assert.Equal(t, range0.lower.Val.(*planpb.GenericValue_StringVal).StringVal, "aaa") + assert.Nil(t, range0.upper) + assert.Equal(t, range0.includeLower, true) + assert.Equal(t, range0.includeUpper, false) + } +} diff --git a/internal/util/testutil/test_util.go b/internal/util/testutil/test_util.go new file mode 100644 index 0000000000..8fb7856b83 --- /dev/null +++ b/internal/util/testutil/test_util.go @@ -0,0 +1,90 @@ +package testutil + +import ( + "strconv" + + 
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/common" +) + +const ( + testMaxVarCharLength = 100 +) + +func ConstructCollectionSchemaWithKeys(collectionName string, + fieldName2DataType map[string]schemapb.DataType, + primaryFieldName string, + partitionKeyFieldName string, + clusteringKeyFieldName string, + autoID bool, + dim int, +) *schemapb.CollectionSchema { + schema := ConstructCollectionSchemaByDataType(collectionName, + fieldName2DataType, + primaryFieldName, + autoID, + dim) + for _, field := range schema.Fields { + if field.Name == partitionKeyFieldName { + field.IsPartitionKey = true + } + if field.Name == clusteringKeyFieldName { + field.IsClusteringKey = true + } + } + + return schema +} + +func isVectorType(dataType schemapb.DataType) bool { + return dataType == schemapb.DataType_FloatVector || + dataType == schemapb.DataType_BinaryVector || + dataType == schemapb.DataType_Float16Vector || + dataType == schemapb.DataType_BFloat16Vector +} + +func ConstructCollectionSchemaByDataType(collectionName string, + fieldName2DataType map[string]schemapb.DataType, + primaryFieldName string, + autoID bool, + dim int, +) *schemapb.CollectionSchema { + fieldsSchema := make([]*schemapb.FieldSchema, 0) + fieldIdx := int64(0) + for fieldName, dataType := range fieldName2DataType { + fieldSchema := &schemapb.FieldSchema{ + Name: fieldName, + DataType: dataType, + FieldID: fieldIdx, + } + fieldIdx += 1 + if isVectorType(dataType) { + fieldSchema.TypeParams = []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: strconv.Itoa(dim), + }, + } + } + if dataType == schemapb.DataType_VarChar { + fieldSchema.TypeParams = []*commonpb.KeyValuePair{ + { + Key: common.MaxLengthKey, + Value: strconv.Itoa(testMaxVarCharLength), + }, + } + } + if fieldName == primaryFieldName { + fieldSchema.IsPrimaryKey = true + fieldSchema.AutoID = autoID + } + + fieldsSchema = append(fieldsSchema, fieldSchema) + } + + return &schemapb.CollectionSchema{ + Name: collectionName, + Fields: fieldsSchema, + } +} diff --git a/internal/util/typeutil/schema.go b/internal/util/typeutil/schema.go index 77d4159b83..3140b140b0 100644 --- a/internal/util/typeutil/schema.go +++ b/internal/util/typeutil/schema.go @@ -119,3 +119,12 @@ func convertToArrowType(dataType schemapb.DataType) (arrow.DataType, error) { return nil, merr.WrapErrParameterInvalidMsg("unknown type %v", dataType.String()) } } + +func GetClusteringKeyField(fields []*schemapb.FieldSchema) *schemapb.FieldSchema { + for _, field := range fields { + if field.IsClusteringKey { + return field + } + } + return nil +} diff --git a/pkg/common/common.go b/pkg/common/common.go index 6c54347ccd..aca0621fed 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -88,6 +88,9 @@ const ( // SegmentIndexPath storage path const for segment index files. 
SegmentIndexPath = `index_files` + + // PartitionStatsPath storage path const for partition stats files + PartitionStatsPath = `part_stats` ) // Search, Index parameter keys diff --git a/pkg/util/paramtable/component_param.go b/pkg/util/paramtable/component_param.go index ac5a6c3a82..5e0ce1e8a0 100644 --- a/pkg/util/paramtable/component_param.go +++ b/pkg/util/paramtable/component_param.go @@ -2039,6 +2039,7 @@ type queryNodeConfig struct { FlowGraphMaxParallelism ParamItem `refreshable:"false"` MemoryIndexLoadPredictMemoryUsageFactor ParamItem `refreshable:"true"` + EnableSegmentPrune ParamItem `refreshable:"false"` } func (p *queryNodeConfig) init(base *BaseTable) { @@ -2512,6 +2513,13 @@ Max read concurrency must greater than or equal to 1, and less than or equal to Doc: "memory usage prediction factor for memory index loaded", } p.MemoryIndexLoadPredictMemoryUsageFactor.Init(base.mgr) + p.EnableSegmentPrune = ParamItem{ + Key: "queryNode.enableSegmentPrune", + Version: "2.3.4", + DefaultValue: "false", + Doc: "use partition prune function on shard delegator", + } + p.EnableSegmentPrune.Init(base.mgr) } // /////////////////////////////////////////////////////////////////////////////
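
A minimal usage sketch for the new exprutil helpers (not part of the patch itself): it assumes an int64 clustering key, and the function name pruneSegment plus the segMin/segMax arguments are hypothetical stand-ins for the per-segment clustering-key statistics that PartitionStatsSnapshot is meant to carry. The actual shard-delegator wiring behind enableSegmentPrune is not shown in this section.

package segprune // hypothetical package, for illustration only

import (
	"github.com/milvus-io/milvus/internal/proto/planpb"
	"github.com/milvus-io/milvus/internal/util/exprutil"
)

// pruneSegment reports whether a segment whose int64 clustering-key values all
// fall inside [segMin, segMax] can be skipped for the given search/query plan.
func pruneSegment(plan *planpb.PlanNode, segMin, segMax int64) (bool, error) {
	expr, err := exprutil.ParseExprFromPlan(plan)
	if err != nil {
		return false, err
	}
	// Extract clustering-key ranges from the predicate expression.
	ranges, matchALL := exprutil.ParseRanges(expr, exprutil.ClusteringKey)
	if matchALL {
		// The expression could not be narrowed down: every segment may contain hits.
		return false, nil
	}
	segRange := exprutil.NewIntRange(segMin, segMax, true, true)
	for _, r := range ranges {
		if exprutil.IntRangeOverlap(r.ToIntRange(), segRange) {
			// At least one predicate range touches this segment, so keep it.
			return false, nil
		}
	}
	// No predicate range overlaps the segment's clustering-key span: prune it.
	return true, nil
}

Note that a segment is pruned only when matchALL is false; an empty range list with matchALL == true means the predicate was unparseable for pruning purposes, not that it matches nothing.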