diff --git a/internal/querynodev2/delegator/ScalarPruner.go b/internal/querynodev2/delegator/ScalarPruner.go index 2d13ea496d..c2fe95fb83 100644 --- a/internal/querynodev2/delegator/ScalarPruner.go +++ b/internal/querynodev2/delegator/ScalarPruner.go @@ -203,62 +203,81 @@ func NewParseContext(keyField FieldID, dType schemapb.DataType) *ParseContext { return &ParseContext{keyField, dType} } -func ParseExpr(exprPb *planpb.Expr, parseCtx *ParseContext) Expr { +func ParseExpr(exprPb *planpb.Expr, parseCtx *ParseContext) (Expr, error) { var res Expr + var err error switch exp := exprPb.GetExpr().(type) { case *planpb.Expr_BinaryExpr: - res = ParseLogicalBinaryExpr(exp.BinaryExpr, parseCtx) + res, err = ParseLogicalBinaryExpr(exp.BinaryExpr, parseCtx) case *planpb.Expr_UnaryExpr: - res = ParseLogicalUnaryExpr(exp.UnaryExpr, parseCtx) + res, err = ParseLogicalUnaryExpr(exp.UnaryExpr, parseCtx) case *planpb.Expr_BinaryRangeExpr: - res = ParseBinaryRangeExpr(exp.BinaryRangeExpr, parseCtx) + res, err = ParseBinaryRangeExpr(exp.BinaryRangeExpr, parseCtx) case *planpb.Expr_UnaryRangeExpr: - res = ParseUnaryRangeExpr(exp.UnaryRangeExpr, parseCtx) + res, err = ParseUnaryRangeExpr(exp.UnaryRangeExpr, parseCtx) case *planpb.Expr_TermExpr: - res = ParseTermExpr(exp.TermExpr, parseCtx) + res, err = ParseTermExpr(exp.TermExpr, parseCtx) } - return res + return res, err } -func ParseLogicalBinaryExpr(exprPb *planpb.BinaryExpr, parseCtx *ParseContext) Expr { - leftExpr := ParseExpr(exprPb.Left, parseCtx) - rightExpr := ParseExpr(exprPb.Right, parseCtx) - return NewLogicalBinaryExpr(leftExpr, rightExpr, exprPb.GetOp()) +func ParseLogicalBinaryExpr(exprPb *planpb.BinaryExpr, parseCtx *ParseContext) (Expr, error) { + leftExpr, err := ParseExpr(exprPb.Left, parseCtx) + if err != nil { + return nil, err + } + rightExpr, err := ParseExpr(exprPb.Right, parseCtx) + if err != nil { + return nil, err + } + return NewLogicalBinaryExpr(leftExpr, rightExpr, exprPb.GetOp()), nil } -func ParseLogicalUnaryExpr(exprPb *planpb.UnaryExpr, parseCtx *ParseContext) Expr { +func ParseLogicalUnaryExpr(exprPb *planpb.UnaryExpr, parseCtx *ParseContext) (Expr, error) { // currently we don't handle NOT expr, this part of code is left for logical integrity - return nil + return nil, nil } -func ParseBinaryRangeExpr(exprPb *planpb.BinaryRangeExpr, parseCtx *ParseContext) Expr { +func ParseBinaryRangeExpr(exprPb *planpb.BinaryRangeExpr, parseCtx *ParseContext) (Expr, error) { if exprPb.GetColumnInfo().GetFieldId() != parseCtx.keyFieldIDToPrune { - return nil + return nil, nil } - lower := storage.NewScalarFieldValueFromGenericValue(parseCtx.dataType, exprPb.GetLowerValue()) - upper := storage.NewScalarFieldValueFromGenericValue(parseCtx.dataType, exprPb.GetUpperValue()) - return NewBinaryRangeExpr(lower, upper, exprPb.LowerInclusive, exprPb.UpperInclusive) + lower, err := storage.NewScalarFieldValueFromGenericValue(parseCtx.dataType, exprPb.GetLowerValue()) + if err != nil { + return nil, err + } + upper, err := storage.NewScalarFieldValueFromGenericValue(parseCtx.dataType, exprPb.GetUpperValue()) + if err != nil { + return nil, err + } + return NewBinaryRangeExpr(lower, upper, exprPb.LowerInclusive, exprPb.UpperInclusive), nil } -func ParseUnaryRangeExpr(exprPb *planpb.UnaryRangeExpr, parseCtx *ParseContext) Expr { +func ParseUnaryRangeExpr(exprPb *planpb.UnaryRangeExpr, parseCtx *ParseContext) (Expr, error) { if exprPb.GetColumnInfo().GetFieldId() != parseCtx.keyFieldIDToPrune { - return nil + return nil, nil } if exprPb.GetOp() == 
planpb.OpType_NotEqual { - return nil + return nil, nil // segment-prune based on min-max cannot support not equal semantic } - innerVal := storage.NewScalarFieldValueFromGenericValue(parseCtx.dataType, exprPb.GetValue()) - return NewUnaryRangeExpr(innerVal, exprPb.GetOp()) + innerVal, err := storage.NewScalarFieldValueFromGenericValue(parseCtx.dataType, exprPb.GetValue()) + if err != nil { + return nil, err + } + return NewUnaryRangeExpr(innerVal, exprPb.GetOp()), nil } -func ParseTermExpr(exprPb *planpb.TermExpr, parseCtx *ParseContext) Expr { +func ParseTermExpr(exprPb *planpb.TermExpr, parseCtx *ParseContext) (Expr, error) { if exprPb.GetColumnInfo().GetFieldId() != parseCtx.keyFieldIDToPrune { - return nil + return nil, nil } scalarVals := make([]storage.ScalarFieldValue, 0) for _, val := range exprPb.GetValues() { - scalarVals = append(scalarVals, storage.NewScalarFieldValueFromGenericValue(parseCtx.dataType, val)) + innerVal, err := storage.NewScalarFieldValueFromGenericValue(parseCtx.dataType, val) + if err == nil { + scalarVals = append(scalarVals, innerVal) + } } - return NewTermExpr(scalarVals) + return NewTermExpr(scalarVals), nil } diff --git a/internal/querynodev2/delegator/segment_pruner.go b/internal/querynodev2/delegator/segment_pruner.go index d5b1116d39..9324226e81 100644 --- a/internal/querynodev2/delegator/segment_pruner.go +++ b/internal/querynodev2/delegator/segment_pruner.go @@ -102,7 +102,11 @@ func PruneSegments(ctx context.Context, } // 1. parse expr for prune - expr := ParseExpr(exprPb, NewParseContext(clusteringKeyField.GetFieldID(), clusteringKeyField.GetDataType())) + expr, err := ParseExpr(exprPb, NewParseContext(clusteringKeyField.GetFieldID(), clusteringKeyField.GetDataType())) + if err != nil { + log.Ctx(ctx).RatedWarn(10, "failed to parse expr for segment prune, fallback to common search/query", zap.Error(err)) + return + } // 2. 
prune segments by scalar field targetSegmentStats := make([]storage.SegmentStats, 0, 32) diff --git a/internal/querynodev2/delegator/segment_pruner_test.go b/internal/querynodev2/delegator/segment_pruner_test.go index 41bd7a96bb..cf925ec201 100644 --- a/internal/querynodev2/delegator/segment_pruner_test.go +++ b/internal/querynodev2/delegator/segment_pruner_test.go @@ -25,24 +25,19 @@ type SegmentPrunerSuite struct { collectionName string primaryFieldName string clusterKeyFieldName string - autoID bool targetPartition int64 dim int sealedSegments []SnapshotItem } -func (sps *SegmentPrunerSuite) SetupForClustering(clusterKeyFieldName string, - clusterKeyFieldType schemapb.DataType, -) { +func (sps *SegmentPrunerSuite) SetupForClustering(clusterKeyFieldName string) { sps.collectionName = "test_segment_prune" sps.primaryFieldName = "pk" sps.clusterKeyFieldName = clusterKeyFieldName - sps.autoID = true sps.dim = 8 fieldName2DataType := make(map[string]schemapb.DataType) fieldName2DataType[sps.primaryFieldName] = schemapb.DataType_Int64 - fieldName2DataType[sps.clusterKeyFieldName] = clusterKeyFieldType fieldName2DataType["info"] = schemapb.DataType_VarChar fieldName2DataType["age"] = schemapb.DataType_Int64 fieldName2DataType["vec"] = schemapb.DataType_FloatVector @@ -88,8 +83,8 @@ func (sps *SegmentPrunerSuite) SetupForClustering(clusterKeyFieldName string, // into the same struct, in the real user cases, a field stat // can either contain min&&max or centroids segStats := make(map[UniqueID]storage.SegmentStats) - switch clusterKeyFieldType { - case schemapb.DataType_Int64, schemapb.DataType_Int32, schemapb.DataType_Int16, schemapb.DataType_Int8: + switch fieldName2DataType[sps.clusterKeyFieldName] { + case schemapb.DataType_Int64: { fieldStats := make([]storage.FieldStats, 0) fieldStat1 := storage.FieldStats{ @@ -144,8 +139,8 @@ func (sps *SegmentPrunerSuite) SetupForClustering(clusterKeyFieldName string, fieldStat1 := storage.FieldStats{ FieldID: clusteringKeyFieldID, Type: schemapb.DataType_VarChar, - Min: storage.NewStringFieldValue("ab"), - Max: storage.NewStringFieldValue("bbc"), + Min: storage.NewVarCharFieldValue("ab"), + Max: storage.NewVarCharFieldValue("bbc"), Centroids: centroids1, } fieldStats = append(fieldStats, fieldStat1) @@ -156,8 +151,8 @@ func (sps *SegmentPrunerSuite) SetupForClustering(clusterKeyFieldName string, fieldStat1 := storage.FieldStats{ FieldID: clusteringKeyFieldID, Type: schemapb.DataType_VarChar, - Min: storage.NewStringFieldValue("hhh"), - Max: storage.NewStringFieldValue("jjx"), + Min: storage.NewVarCharFieldValue("hhh"), + Max: storage.NewVarCharFieldValue("jjx"), Centroids: centroids2, } fieldStats = append(fieldStats, fieldStat1) @@ -168,8 +163,8 @@ func (sps *SegmentPrunerSuite) SetupForClustering(clusterKeyFieldName string, fieldStat1 := storage.FieldStats{ FieldID: clusteringKeyFieldID, Type: schemapb.DataType_VarChar, - Min: storage.NewStringFieldValue("kkk"), - Max: storage.NewStringFieldValue("lmn"), + Min: storage.NewVarCharFieldValue("kkk"), + Max: storage.NewVarCharFieldValue("lmn"), Centroids: centroids3, } fieldStats = append(fieldStats, fieldStat1) @@ -180,8 +175,8 @@ func (sps *SegmentPrunerSuite) SetupForClustering(clusterKeyFieldName string, fieldStat1 := storage.FieldStats{ FieldID: clusteringKeyFieldID, Type: schemapb.DataType_VarChar, - Min: storage.NewStringFieldValue("oo2"), - Max: storage.NewStringFieldValue("pptt"), + Min: storage.NewVarCharFieldValue("oo2"), + Max: storage.NewVarCharFieldValue("pptt"), Centroids: centroids4, } 
fieldStats = append(fieldStats, fieldStat1) @@ -227,7 +222,7 @@ func (sps *SegmentPrunerSuite) SetupForClustering(clusterKeyFieldName string, } func (sps *SegmentPrunerSuite) TestPruneSegmentsByScalarIntField() { - sps.SetupForClustering("age", schemapb.DataType_Int32) + sps.SetupForClustering("age") paramtable.Init() targetPartitions := make([]UniqueID, 0) targetPartitions = append(targetPartitions, sps.targetPartition) @@ -423,7 +418,7 @@ func (sps *SegmentPrunerSuite) TestPruneSegmentsByScalarIntField() { } func (sps *SegmentPrunerSuite) TestPruneSegmentsWithUnrelatedField() { - sps.SetupForClustering("age", schemapb.DataType_Int32) + sps.SetupForClustering("age") paramtable.Init() targetPartitions := make([]UniqueID, 0) targetPartitions = append(targetPartitions, sps.targetPartition) @@ -539,7 +534,7 @@ func (sps *SegmentPrunerSuite) TestPruneSegmentsWithUnrelatedField() { } func (sps *SegmentPrunerSuite) TestPruneSegmentsByScalarStrField() { - sps.SetupForClustering("info", schemapb.DataType_VarChar) + sps.SetupForClustering("info") paramtable.Init() targetPartitions := make([]UniqueID, 0) targetPartitions = append(targetPartitions, sps.targetPartition) @@ -618,7 +613,7 @@ func vector2Placeholder(vectors [][]float32) *commonpb.PlaceholderValue { func (sps *SegmentPrunerSuite) TestPruneSegmentsByVectorField() { paramtable.Init() paramtable.Get().Save(paramtable.Get().CommonCfg.EnableVectorClusteringKey.Key, "true") - sps.SetupForClustering("vec", schemapb.DataType_FloatVector) + sps.SetupForClustering("vec") vector1 := []float32{0.8877872002188053, 0.6131822285635065, 0.8476814632326242, 0.6645877829359371, 0.9962627712600025, 0.8976183052440327, 0.41941169325798844, 0.7554387854258499} vector2 := []float32{0.8644394874390322, 0.023327886647378615, 0.08330118483461302, 0.7068040179963112, 0.6983994910799851, 0.5562075958994153, 0.3288536247938002, 0.07077341010237759} vectors := [][]float32{vector1, vector2} @@ -644,6 +639,729 @@ func (sps *SegmentPrunerSuite) TestPruneSegmentsByVectorField() { sps.Equal(int64(3), sps.sealedSegments[1].Segments[0].SegmentID) } +func (sps *SegmentPrunerSuite) TestPruneSegmentsVariousIntTypes() { + paramtable.Init() + collectionName := "test_segment_prune" + primaryFieldName := "pk" + dim := 8 + var targetPartition int64 = 1 + const INT8 = "int8" + const INT16 = "int16" + const INT32 = "int32" + const INT64 = "int64" + const VEC = "vec" + + fieldName2DataType := make(map[string]schemapb.DataType) + fieldName2DataType[primaryFieldName] = schemapb.DataType_Int64 + fieldName2DataType[INT8] = schemapb.DataType_Int8 + fieldName2DataType[INT16] = schemapb.DataType_Int16 + fieldName2DataType[INT32] = schemapb.DataType_Int32 + fieldName2DataType[INT64] = schemapb.DataType_Int64 + fieldName2DataType[VEC] = schemapb.DataType_FloatVector + + { + // test for int8 cluster field + clusterFieldName := INT8 + schema := testutil.ConstructCollectionSchemaWithKeys(collectionName, + fieldName2DataType, + primaryFieldName, + "", + clusterFieldName, + false, + dim) + + var clusteringKeyFieldID int64 = 0 + for _, field := range schema.GetFields() { + if field.IsClusteringKey { + clusteringKeyFieldID = field.FieldID + break + } + } + + // set up part stats + segStats := make(map[UniqueID]storage.SegmentStats) + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Int8, + Min: storage.NewInt8FieldValue(-127), + Max: storage.NewInt8FieldValue(-23), + } + fieldStats = append(fieldStats, 
fieldStat1) + segStats[1] = *storage.NewSegmentStats(fieldStats, 80) + } + + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Int8, + Min: storage.NewInt8FieldValue(-22), + Max: storage.NewInt8FieldValue(-8), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[2] = *storage.NewSegmentStats(fieldStats, 80) + } + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Int8, + Min: storage.NewInt8FieldValue(-7), + Max: storage.NewInt8FieldValue(15), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[3] = *storage.NewSegmentStats(fieldStats, 80) + } + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Int8, + Min: storage.NewInt8FieldValue(16), + Max: storage.NewInt8FieldValue(127), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[4] = *storage.NewSegmentStats(fieldStats, 80) + } + partitionStats := make(map[UniqueID]*storage.PartitionStatsSnapshot) + partitionStats[targetPartition] = &storage.PartitionStatsSnapshot{ + SegmentStats: segStats, + } + sealedSegments := make([]SnapshotItem, 0) + item1 := SnapshotItem{ + NodeID: 1, + Segments: []SegmentEntry{ + { + NodeID: 1, + SegmentID: 1, + }, + { + NodeID: 1, + SegmentID: 2, + }, + }, + } + item2 := SnapshotItem{ + NodeID: 2, + Segments: []SegmentEntry{ + { + NodeID: 2, + SegmentID: 3, + }, + { + NodeID: 2, + SegmentID: 4, + }, + }, + } + sealedSegments = append(sealedSegments, item1) + sealedSegments = append(sealedSegments, item2) + + { + // test out bound int expr, fallback to common search + testSegments := make([]SnapshotItem, 0) + copy(testSegments, sealedSegments) + exprStr := "int8 > 128" + schemaHelper, _ := typeutil.CreateSchemaHelper(schema) + planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr) + sps.NoError(err) + serializedPlan, _ := proto.Marshal(planNode) + queryReq := &internalpb.RetrieveRequest{ + SerializedExprPlan: serializedPlan, + } + PruneSegments(context.TODO(), partitionStats, nil, queryReq, schema, sealedSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()}) + sps.Equal(2, len(sealedSegments[0].Segments)) + sps.Equal(2, len(sealedSegments[1].Segments)) + } + { + // test out bound int expr, fallback to common search + testSegments := make([]SnapshotItem, 0) + copy(testSegments, sealedSegments) + exprStr := "int8 < -129" + schemaHelper, _ := typeutil.CreateSchemaHelper(schema) + planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr) + sps.NoError(err) + serializedPlan, _ := proto.Marshal(planNode) + queryReq := &internalpb.RetrieveRequest{ + SerializedExprPlan: serializedPlan, + } + PruneSegments(context.TODO(), partitionStats, nil, queryReq, schema, sealedSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()}) + sps.Equal(2, len(sealedSegments[0].Segments)) + sps.Equal(2, len(sealedSegments[1].Segments)) + } + { + // test out bound int expr, fallback to common search + testSegments := make([]SnapshotItem, 0) + copy(testSegments, sealedSegments) + exprStr := "int8 > 50" + schemaHelper, _ := typeutil.CreateSchemaHelper(schema) + planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr) + sps.NoError(err) + serializedPlan, _ := proto.Marshal(planNode) + queryReq := &internalpb.RetrieveRequest{ + 
SerializedExprPlan: serializedPlan, + } + PruneSegments(context.TODO(), partitionStats, nil, queryReq, schema, sealedSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()}) + sps.Equal(0, len(sealedSegments[0].Segments)) + sps.Equal(1, len(sealedSegments[1].Segments)) + } + } + { + // test for int16 cluster field + clusterFieldName := INT16 + schema := testutil.ConstructCollectionSchemaWithKeys(collectionName, + fieldName2DataType, + primaryFieldName, + "", + clusterFieldName, + false, + dim) + + var clusteringKeyFieldID int64 = 0 + for _, field := range schema.GetFields() { + if field.IsClusteringKey { + clusteringKeyFieldID = field.FieldID + break + } + } + + // set up part stats + segStats := make(map[UniqueID]storage.SegmentStats) + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Int16, + Min: storage.NewInt16FieldValue(-3127), + Max: storage.NewInt16FieldValue(-2123), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[1] = *storage.NewSegmentStats(fieldStats, 80) + } + + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Int16, + Min: storage.NewInt16FieldValue(-2112), + Max: storage.NewInt16FieldValue(-1118), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[2] = *storage.NewSegmentStats(fieldStats, 80) + } + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Int16, + Min: storage.NewInt16FieldValue(-17), + Max: storage.NewInt16FieldValue(3315), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[3] = *storage.NewSegmentStats(fieldStats, 80) + } + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Int16, + Min: storage.NewInt16FieldValue(3415), + Max: storage.NewInt16FieldValue(4127), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[4] = *storage.NewSegmentStats(fieldStats, 80) + } + partitionStats := make(map[UniqueID]*storage.PartitionStatsSnapshot) + partitionStats[targetPartition] = &storage.PartitionStatsSnapshot{ + SegmentStats: segStats, + } + sealedSegments := make([]SnapshotItem, 0) + item1 := SnapshotItem{ + NodeID: 1, + Segments: []SegmentEntry{ + { + NodeID: 1, + SegmentID: 1, + }, + { + NodeID: 1, + SegmentID: 2, + }, + }, + } + item2 := SnapshotItem{ + NodeID: 2, + Segments: []SegmentEntry{ + { + NodeID: 2, + SegmentID: 3, + }, + { + NodeID: 2, + SegmentID: 4, + }, + }, + } + sealedSegments = append(sealedSegments, item1) + sealedSegments = append(sealedSegments, item2) + + { + // test out bound int expr, fallback to common search + testSegments := make([]SnapshotItem, 0) + copy(testSegments, sealedSegments) + exprStr := "int16 > 32768" + schemaHelper, _ := typeutil.CreateSchemaHelper(schema) + planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr) + sps.NoError(err) + serializedPlan, _ := proto.Marshal(planNode) + queryReq := &internalpb.RetrieveRequest{ + SerializedExprPlan: serializedPlan, + } + PruneSegments(context.TODO(), partitionStats, nil, queryReq, schema, sealedSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()}) + sps.Equal(2, len(sealedSegments[0].Segments)) + sps.Equal(2, len(sealedSegments[1].Segments)) + } + { + // test out bound int expr, fallback to common 
search + testSegments := make([]SnapshotItem, 0) + copy(testSegments, sealedSegments) + exprStr := "int16 < -32769" + schemaHelper, _ := typeutil.CreateSchemaHelper(schema) + planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr) + sps.NoError(err) + serializedPlan, _ := proto.Marshal(planNode) + queryReq := &internalpb.RetrieveRequest{ + SerializedExprPlan: serializedPlan, + } + PruneSegments(context.TODO(), partitionStats, nil, queryReq, schema, sealedSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()}) + sps.Equal(2, len(sealedSegments[0].Segments)) + sps.Equal(2, len(sealedSegments[1].Segments)) + } + { + // test out bound int expr, fallback to common search + testSegments := make([]SnapshotItem, 0) + copy(testSegments, sealedSegments) + exprStr := "int16 > 2550" + schemaHelper, _ := typeutil.CreateSchemaHelper(schema) + planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr) + sps.NoError(err) + serializedPlan, _ := proto.Marshal(planNode) + queryReq := &internalpb.RetrieveRequest{ + SerializedExprPlan: serializedPlan, + } + PruneSegments(context.TODO(), partitionStats, nil, queryReq, schema, sealedSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()}) + sps.Equal(0, len(sealedSegments[0].Segments)) + sps.Equal(2, len(sealedSegments[1].Segments)) + } + } + + { + // test for int32 cluster field + clusterFieldName := INT32 + schema := testutil.ConstructCollectionSchemaWithKeys(collectionName, + fieldName2DataType, + primaryFieldName, + "", + clusterFieldName, + false, + dim) + + var clusteringKeyFieldID int64 = 0 + for _, field := range schema.GetFields() { + if field.IsClusteringKey { + clusteringKeyFieldID = field.FieldID + break + } + } + + // set up part stats + segStats := make(map[UniqueID]storage.SegmentStats) + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Int32, + Min: storage.NewInt32FieldValue(-13127), + Max: storage.NewInt32FieldValue(-12123), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[1] = *storage.NewSegmentStats(fieldStats, 80) + } + + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Int32, + Min: storage.NewInt32FieldValue(-5127), + Max: storage.NewInt32FieldValue(-3123), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[2] = *storage.NewSegmentStats(fieldStats, 80) + } + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Int32, + Min: storage.NewInt32FieldValue(-3121), + Max: storage.NewInt32FieldValue(-1123), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[3] = *storage.NewSegmentStats(fieldStats, 80) + } + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Int32, + Min: storage.NewInt32FieldValue(3121), + Max: storage.NewInt32FieldValue(41123), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[4] = *storage.NewSegmentStats(fieldStats, 80) + } + partitionStats := make(map[UniqueID]*storage.PartitionStatsSnapshot) + partitionStats[targetPartition] = &storage.PartitionStatsSnapshot{ + SegmentStats: segStats, + } + sealedSegments := make([]SnapshotItem, 0) + item1 := SnapshotItem{ + NodeID: 1, + Segments: []SegmentEntry{ + { + NodeID: 1, + 
SegmentID: 1, + }, + { + NodeID: 1, + SegmentID: 2, + }, + }, + } + item2 := SnapshotItem{ + NodeID: 2, + Segments: []SegmentEntry{ + { + NodeID: 2, + SegmentID: 3, + }, + { + NodeID: 2, + SegmentID: 4, + }, + }, + } + sealedSegments = append(sealedSegments, item1) + sealedSegments = append(sealedSegments, item2) + + { + // test out bound int expr, fallback to common search + testSegments := make([]SnapshotItem, 0) + copy(testSegments, sealedSegments) + exprStr := "int32 > 2147483648" + schemaHelper, _ := typeutil.CreateSchemaHelper(schema) + planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr) + sps.NoError(err) + serializedPlan, _ := proto.Marshal(planNode) + queryReq := &internalpb.RetrieveRequest{ + SerializedExprPlan: serializedPlan, + } + PruneSegments(context.TODO(), partitionStats, nil, queryReq, schema, sealedSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()}) + sps.Equal(2, len(sealedSegments[0].Segments)) + sps.Equal(2, len(sealedSegments[1].Segments)) + } + { + // test out bound int expr, fallback to common search + testSegments := make([]SnapshotItem, 0) + copy(testSegments, sealedSegments) + exprStr := "int32 < -2147483649" + schemaHelper, _ := typeutil.CreateSchemaHelper(schema) + planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr) + sps.NoError(err) + serializedPlan, _ := proto.Marshal(planNode) + queryReq := &internalpb.RetrieveRequest{ + SerializedExprPlan: serializedPlan, + } + PruneSegments(context.TODO(), partitionStats, nil, queryReq, schema, sealedSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()}) + sps.Equal(2, len(sealedSegments[0].Segments)) + sps.Equal(2, len(sealedSegments[1].Segments)) + } + { + // test out bound int expr, fallback to common search + testSegments := make([]SnapshotItem, 0) + copy(testSegments, sealedSegments) + exprStr := "int32 > 12550" + schemaHelper, _ := typeutil.CreateSchemaHelper(schema) + planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr) + sps.NoError(err) + serializedPlan, _ := proto.Marshal(planNode) + queryReq := &internalpb.RetrieveRequest{ + SerializedExprPlan: serializedPlan, + } + PruneSegments(context.TODO(), partitionStats, nil, queryReq, schema, sealedSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()}) + sps.Equal(0, len(sealedSegments[0].Segments)) + sps.Equal(1, len(sealedSegments[1].Segments)) + } + } +} + +func (sps *SegmentPrunerSuite) TestPruneSegmentsFloatTypes() { + paramtable.Init() + collectionName := "test_segment_prune" + primaryFieldName := "pk" + dim := 8 + var targetPartition int64 = 1 + const FLOAT = "float" + const DOUBLE = "double" + const VEC = "vec" + + fieldName2DataType := make(map[string]schemapb.DataType) + fieldName2DataType[primaryFieldName] = schemapb.DataType_Int64 + fieldName2DataType[FLOAT] = schemapb.DataType_Float + fieldName2DataType[DOUBLE] = schemapb.DataType_Double + fieldName2DataType[VEC] = schemapb.DataType_FloatVector + + { + // test for float cluster field + clusterFieldName := FLOAT + schema := testutil.ConstructCollectionSchemaWithKeys(collectionName, + fieldName2DataType, + primaryFieldName, + "", + clusterFieldName, + false, + dim) + + var clusteringKeyFieldID int64 = 0 + for _, field := range schema.GetFields() { + if field.IsClusteringKey { + clusteringKeyFieldID = field.FieldID + break + } + } + + // set up part stats + segStats := make(map[UniqueID]storage.SegmentStats) + { + fieldStats := 
make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Float, + Min: storage.NewFloatFieldValue(-3.0), + Max: storage.NewFloatFieldValue(-1.0), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[1] = *storage.NewSegmentStats(fieldStats, 80) + } + + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Float, + Min: storage.NewFloatFieldValue(-0.5), + Max: storage.NewFloatFieldValue(2.0), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[2] = *storage.NewSegmentStats(fieldStats, 80) + } + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Float, + Min: storage.NewFloatFieldValue(2.5), + Max: storage.NewFloatFieldValue(5.0), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[3] = *storage.NewSegmentStats(fieldStats, 80) + } + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Float, + Min: storage.NewFloatFieldValue(5.5), + Max: storage.NewFloatFieldValue(8.0), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[4] = *storage.NewSegmentStats(fieldStats, 80) + } + partitionStats := make(map[UniqueID]*storage.PartitionStatsSnapshot) + partitionStats[targetPartition] = &storage.PartitionStatsSnapshot{ + SegmentStats: segStats, + } + sealedSegments := make([]SnapshotItem, 0) + item1 := SnapshotItem{ + NodeID: 1, + Segments: []SegmentEntry{ + { + NodeID: 1, + SegmentID: 1, + }, + { + NodeID: 1, + SegmentID: 2, + }, + }, + } + item2 := SnapshotItem{ + NodeID: 2, + Segments: []SegmentEntry{ + { + NodeID: 2, + SegmentID: 3, + }, + { + NodeID: 2, + SegmentID: 4, + }, + }, + } + sealedSegments = append(sealedSegments, item1) + sealedSegments = append(sealedSegments, item2) + + { + // test out bound int expr, fallback to common search + testSegments := make([]SnapshotItem, 0) + copy(testSegments, sealedSegments) + exprStr := "float > 3.5" + schemaHelper, _ := typeutil.CreateSchemaHelper(schema) + planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr) + sps.NoError(err) + serializedPlan, _ := proto.Marshal(planNode) + queryReq := &internalpb.RetrieveRequest{ + SerializedExprPlan: serializedPlan, + } + PruneSegments(context.TODO(), partitionStats, nil, queryReq, schema, sealedSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()}) + sps.Equal(0, len(sealedSegments[0].Segments)) + sps.Equal(2, len(sealedSegments[1].Segments)) + } + } + + { + // test for double cluster field + clusterFieldName := DOUBLE + schema := testutil.ConstructCollectionSchemaWithKeys(collectionName, + fieldName2DataType, + primaryFieldName, + "", + clusterFieldName, + false, + dim) + + var clusteringKeyFieldID int64 = 0 + for _, field := range schema.GetFields() { + if field.IsClusteringKey { + clusteringKeyFieldID = field.FieldID + break + } + } + + // set up part stats + segStats := make(map[UniqueID]storage.SegmentStats) + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Double, + Min: storage.NewDoubleFieldValue(-3.0), + Max: storage.NewDoubleFieldValue(-1.0), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[1] = *storage.NewSegmentStats(fieldStats, 80) + } + + { + fieldStats := 
make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Double, + Min: storage.NewDoubleFieldValue(-0.5), + Max: storage.NewDoubleFieldValue(1.0), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[2] = *storage.NewSegmentStats(fieldStats, 80) + } + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Double, + Min: storage.NewDoubleFieldValue(1.5), + Max: storage.NewDoubleFieldValue(3.0), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[3] = *storage.NewSegmentStats(fieldStats, 80) + } + { + fieldStats := make([]storage.FieldStats, 0) + fieldStat1 := storage.FieldStats{ + FieldID: clusteringKeyFieldID, + Type: schemapb.DataType_Double, + Min: storage.NewDoubleFieldValue(4.0), + Max: storage.NewDoubleFieldValue(5.0), + } + fieldStats = append(fieldStats, fieldStat1) + segStats[4] = *storage.NewSegmentStats(fieldStats, 80) + } + partitionStats := make(map[UniqueID]*storage.PartitionStatsSnapshot) + partitionStats[targetPartition] = &storage.PartitionStatsSnapshot{ + SegmentStats: segStats, + } + sealedSegments := make([]SnapshotItem, 0) + item1 := SnapshotItem{ + NodeID: 1, + Segments: []SegmentEntry{ + { + NodeID: 1, + SegmentID: 1, + }, + { + NodeID: 1, + SegmentID: 2, + }, + }, + } + item2 := SnapshotItem{ + NodeID: 2, + Segments: []SegmentEntry{ + { + NodeID: 2, + SegmentID: 3, + }, + { + NodeID: 2, + SegmentID: 4, + }, + }, + } + sealedSegments = append(sealedSegments, item1) + sealedSegments = append(sealedSegments, item2) + + { + // test out bound int expr, fallback to common search + testSegments := make([]SnapshotItem, 0) + copy(testSegments, sealedSegments) + exprStr := "double < -1.5" + schemaHelper, _ := typeutil.CreateSchemaHelper(schema) + planNode, err := planparserv2.CreateRetrievePlan(schemaHelper, exprStr) + sps.NoError(err) + serializedPlan, _ := proto.Marshal(planNode) + queryReq := &internalpb.RetrieveRequest{ + SerializedExprPlan: serializedPlan, + } + PruneSegments(context.TODO(), partitionStats, nil, queryReq, schema, sealedSegments, PruneInfo{paramtable.Get().QueryNodeCfg.DefaultSegmentFilterRatio.GetAsFloat()}) + sps.Equal(1, len(sealedSegments[0].Segments)) + sps.Equal(0, len(sealedSegments[1].Segments)) + } + } +} + func TestSegmentPrunerSuite(t *testing.T) { suite.Run(t, new(SegmentPrunerSuite)) } diff --git a/internal/storage/field_value.go b/internal/storage/field_value.go index 3e6f0a0323..9c8e16fb34 100644 --- a/internal/storage/field_value.go +++ b/internal/storage/field_value.go @@ -19,11 +19,13 @@ package storage import ( "encoding/json" "fmt" + "math" "strings" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" ) type ScalarFieldValue interface { @@ -1029,17 +1031,43 @@ func (ifv *FloatVectorFieldValue) Size() int64 { return int64(len(ifv.Value) * 8) } -func NewScalarFieldValueFromGenericValue(dtype schemapb.DataType, gVal *planpb.GenericValue) ScalarFieldValue { +func NewScalarFieldValueFromGenericValue(dtype schemapb.DataType, gVal *planpb.GenericValue) (ScalarFieldValue, error) { switch dtype { + case schemapb.DataType_Int8: + i64Val := gVal.Val.(*planpb.GenericValue_Int64Val) + if i64Val.Int64Val > math.MaxInt8 || i64Val.Int64Val < math.MinInt8 { + return nil, merr.WrapErrParameterInvalidRange(math.MinInt8, math.MaxInt8, 
i64Val.Int64Val, "expr value out of bound") + } + return NewInt8FieldValue(int8(i64Val.Int64Val)), nil + + case schemapb.DataType_Int16: + i64Val := gVal.Val.(*planpb.GenericValue_Int64Val) + if i64Val.Int64Val > math.MaxInt16 || i64Val.Int64Val < math.MinInt16 { + return nil, merr.WrapErrParameterInvalidRange(math.MinInt16, math.MaxInt16, i64Val.Int64Val, "expr value out of bound") + } + return NewInt16FieldValue(int16(i64Val.Int64Val)), nil + + case schemapb.DataType_Int32: + i64Val := gVal.Val.(*planpb.GenericValue_Int64Val) + if i64Val.Int64Val > math.MaxInt32 || i64Val.Int64Val < math.MinInt32 { + return nil, merr.WrapErrParameterInvalidRange(math.MinInt32, math.MaxInt32, i64Val.Int64Val, "expr value out of bound") + } + return NewInt32FieldValue(int32(i64Val.Int64Val)), nil case schemapb.DataType_Int64: i64Val := gVal.Val.(*planpb.GenericValue_Int64Val) - return NewInt64FieldValue(i64Val.Int64Val) + return NewInt64FieldValue(i64Val.Int64Val), nil case schemapb.DataType_Float: floatVal := gVal.Val.(*planpb.GenericValue_FloatVal) - return NewFloatFieldValue(float32(floatVal.FloatVal)) - case schemapb.DataType_String, schemapb.DataType_VarChar: + return NewFloatFieldValue(float32(floatVal.FloatVal)), nil + case schemapb.DataType_Double: + floatVal := gVal.Val.(*planpb.GenericValue_FloatVal) + return NewDoubleFieldValue(floatVal.FloatVal), nil + case schemapb.DataType_String: strVal := gVal.Val.(*planpb.GenericValue_StringVal) - return NewStringFieldValue(strVal.StringVal) + return NewStringFieldValue(strVal.StringVal), nil + case schemapb.DataType_VarChar: + strVal := gVal.Val.(*planpb.GenericValue_StringVal) + return NewVarCharFieldValue(strVal.StringVal), nil default: // should not be reach panic(fmt.Sprintf("not supported datatype: %s", dtype.String()))