diff --git a/internal/mysqld/executor/executor.go b/internal/mysqld/executor/executor.go index e1a2f9b67e..47b4fe13a2 100644 --- a/internal/mysqld/executor/executor.go +++ b/internal/mysqld/executor/executor.go @@ -5,18 +5,17 @@ import ( "fmt" "strconv" + "github.com/milvus-io/milvus/pkg/common" + "github.com/cockroachdb/errors" "github.com/milvus-io/milvus/internal/mysqld/parser/antlrparser" "github.com/milvus-io/milvus/internal/util/typeutil" - "github.com/milvus-io/milvus-proto/go-api/schemapb" - "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus-proto/go-api/milvuspb" "github.com/milvus-io/milvus/pkg/util/commonpbutil" - querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" "github.com/milvus-io/milvus/internal/mysqld/planner" "github.com/milvus-io/milvus/internal/types" @@ -75,11 +74,6 @@ func (e *defaultExecutor) execSelect(ctx context.Context, n *planner.NodeSelectS q := stmt.Query.Unwrap() - if q.Limit.IsSome() { - // TODO: use pagination. - return nil, fmt.Errorf("invalid query statement, limit/offset is not supported") - } - if len(q.SelectSpecs) != 0 { return nil, fmt.Errorf("invalid query statement, select spec is not supported") } @@ -117,6 +111,16 @@ func (e *defaultExecutor) execSelect(ctx context.Context, n *planner.NodeSelectS // `match` is false. + if q.Anns.IsSome() { + // reuse the parsed `outputFields`. + return e.execANNS(ctx, q, outputFields) + } + + if q.Limit.IsSome() { + // TODO: use pagination. + return nil, fmt.Errorf("invalid query statement, limit/offset is not supported") + } + if !from.Where.IsSome() { // query without filter. return nil, fmt.Errorf("query without filter is not supported") } @@ -129,42 +133,33 @@ func (e *defaultExecutor) execSelect(ctx context.Context, n *planner.NodeSelectS return wrapQueryResults(res), nil } -func getOutputFieldsOrMatchCountRule(fields []*planner.NodeSelectElement) (outputFields []string, match bool, err error) { - match = false - l := len(fields) - - if l == 1 { - entry := fields[0] - match = entry.FunctionCall.IsSome() && - entry.FunctionCall.Unwrap().Agg.IsSome() && - entry.FunctionCall.Unwrap().Agg.Unwrap().AggCount.IsSome() +func (e *defaultExecutor) execANNS(ctx context.Context, q *planner.NodeQuerySpecification, outputs []string) (*sqltypes.Result, error) { + if !q.Limit.IsSome() { + return nil, fmt.Errorf("limit not specified in the ANNS statement") } - if match { - return nil, match, nil + annsClause := q.Anns.Unwrap() + + searchParams := prepareSearchParams(q.Limit.Unwrap(), annsClause) + + outputsIndex, userOutputs := generateOutputsIndex(outputs) + + filter := restoreExpr(q.From.Unwrap()) + + tableName := q.From.Unwrap().TableSources[0].TableName.Unwrap() + + req := prepareSearchReq(tableName, filter, annsClause.Vectors, userOutputs, searchParams) + + res, err := e.s.Search(ctx, req) + if err != nil { + return nil, err } - outputFields = make([]string, 0, l) - for _, entry := range fields { - if entry.Star.IsSome() { - // TODO: support `select *`. - return nil, match, fmt.Errorf("* is not supported") - } - - if entry.FunctionCall.IsSome() { - return nil, match, fmt.Errorf("combined select elements is not supported") - } - - if entry.FullColumnName.IsSome() { - c := entry.FullColumnName.Unwrap() - if c.Alias.IsSome() { - return nil, match, fmt.Errorf("alias for select elements is not supported") - } - outputFields = append(outputFields, c.Name) - } + if res.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success { + return nil, common.NewStatusError(res.GetStatus().GetErrorCode(), res.GetStatus().GetReason()) } - return outputFields, false, nil + return wrapSearchResult(res, outputsIndex, userOutputs), nil } func (e *defaultExecutor) execCountWithFilter(ctx context.Context, tableName string, filter string) (*sqltypes.Result, error) { @@ -229,126 +224,6 @@ func (e *defaultExecutor) execQuery(ctx context.Context, tableName string, filte return resp, nil } -func wrapCountResult(rowCnt int, column string) *sqltypes.Result { - result1 := &sqltypes.Result{ - Fields: []*querypb.Field{ - { - Name: column, - Type: querypb.Type_INT64, - }, - }, - Rows: [][]sqltypes.Value{ - { - sqltypes.NewInt64(int64(rowCnt)), - }, - }, - } - return result1 -} - -func wrapQueryResults(res *milvuspb.QueryResults) *sqltypes.Result { - fieldsData := res.GetFieldsData() - nColumn := len(fieldsData) - fields := make([]*querypb.Field, 0, nColumn) - - if nColumn <= 0 { - return &sqltypes.Result{} - } - - for i := 0; i < nColumn; i++ { - fields = append(fields, getSQLField(res.GetCollectionName(), fieldsData[i])) - } - - nRow := typeutil.GetRowCount(fieldsData[0]) - rows := make([][]sqltypes.Value, 0, nRow) - for i := 0; i < nRow; i++ { - row := make([]sqltypes.Value, 0, nColumn) - for j := 0; j < nColumn; j++ { - row = append(row, getDataSingle(fieldsData[j], i)) - } - rows = append(rows, row) - } - - return &sqltypes.Result{ - Fields: fields, - Rows: rows, - } -} - -func getSQLField(tableName string, fieldData *schemapb.FieldData) *querypb.Field { - return &querypb.Field{ - Name: fieldData.GetFieldName(), - Type: toSQLType(fieldData.GetType()), - Table: tableName, - OrgTable: "", - Database: "", - OrgName: "", - ColumnLength: 0, - Charset: 0, - Decimals: 0, - Flags: 0, - } -} - -func toSQLType(t schemapb.DataType) querypb.Type { - switch t { - case schemapb.DataType_Bool: - // TODO: tinyint - return querypb.Type_UINT8 - case schemapb.DataType_Int8: - return querypb.Type_INT8 - case schemapb.DataType_Int16: - return querypb.Type_INT16 - case schemapb.DataType_Int32: - return querypb.Type_INT32 - case schemapb.DataType_Int64: - return querypb.Type_INT64 - case schemapb.DataType_Float: - return querypb.Type_FLOAT32 - case schemapb.DataType_Double: - return querypb.Type_FLOAT64 - case schemapb.DataType_VarChar: - return querypb.Type_VARCHAR - // TODO: vector. - default: - return querypb.Type_NULL_TYPE - } -} - -func getDataSingle(fieldData *schemapb.FieldData, idx int) sqltypes.Value { - switch fieldData.GetType() { - case schemapb.DataType_Bool: - // TODO: tinyint - return sqltypes.NewInt32(1) - case schemapb.DataType_Int8: - v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_IntData).IntData.GetData()[idx] - return sqltypes.MakeTrusted(sqltypes.Int8, strconv.AppendInt(nil, int64(v), 10)) - case schemapb.DataType_Int16: - v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_IntData).IntData.GetData()[idx] - return sqltypes.MakeTrusted(sqltypes.Int16, strconv.AppendInt(nil, int64(v), 10)) - case schemapb.DataType_Int32: - v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_IntData).IntData.GetData()[idx] - return sqltypes.MakeTrusted(sqltypes.Int32, strconv.AppendInt(nil, int64(v), 10)) - case schemapb.DataType_Int64: - v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_LongData).LongData.GetData()[idx] - return sqltypes.MakeTrusted(sqltypes.Int64, strconv.AppendInt(nil, v, 10)) - case schemapb.DataType_Float: - v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_FloatData).FloatData.GetData()[idx] - return sqltypes.MakeTrusted(sqltypes.Float32, strconv.AppendFloat(nil, float64(v), 'f', -1, 64)) - case schemapb.DataType_Double: - v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_DoubleData).DoubleData.GetData()[idx] - return sqltypes.MakeTrusted(sqltypes.Float64, strconv.AppendFloat(nil, v, 'g', -1, 64)) - case schemapb.DataType_VarChar: - v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_StringData).StringData.GetData()[idx] - return sqltypes.NewVarChar(v) - - // TODO: vector. - default: - // TODO: should raise error here. - return sqltypes.NewInt32(1) - } -} - func NewDefaultExecutor(s types.ProxyComponent) Executor { return &defaultExecutor{s: s} } diff --git a/internal/mysqld/executor/executor_test.go b/internal/mysqld/executor/executor_test.go index 480adc2fd4..e0202cc7b5 100644 --- a/internal/mysqld/executor/executor_test.go +++ b/internal/mysqld/executor/executor_test.go @@ -4,6 +4,8 @@ import ( "context" "testing" + "github.com/milvus-io/milvus/internal/mysqld/parser/antlrparser" + "github.com/cockroachdb/errors" "github.com/milvus-io/milvus-proto/go-api/schemapb" @@ -445,63 +447,6 @@ func Test_defaultExecutor_execSelect(t *testing.T) { }) } -func Test_getOutputFieldsOrMatchCountRule(t *testing.T) { - t.Run("match count rule", func(t *testing.T) { - fl := []*planner.NodeSelectElement{ - planner.NewNodeSelectElement("", planner.WithFunctionCall( - planner.NewNodeFunctionCall("", planner.WithAgg( - planner.NewNodeAggregateWindowedFunction("", planner.WithAggCount( - planner.NewNodeCount(""))))))), - } - _, match, err := getOutputFieldsOrMatchCountRule(fl) - assert.NoError(t, err) - assert.True(t, match) - }) - - t.Run("star *, not supported", func(t *testing.T) { - fl := []*planner.NodeSelectElement{ - planner.NewNodeSelectElement("", planner.WithStar()), - } - _, _, err := getOutputFieldsOrMatchCountRule(fl) - assert.Error(t, err) - }) - - t.Run("combined", func(t *testing.T) { - fl := []*planner.NodeSelectElement{ - planner.NewNodeSelectElement("", planner.WithFunctionCall( - planner.NewNodeFunctionCall("", planner.WithAgg( - planner.NewNodeAggregateWindowedFunction("", planner.WithAggCount( - planner.NewNodeCount(""))))))), - planner.NewNodeSelectElement("", planner.WithFullColumnName( - planner.NewNodeFullColumnName("", "field"))), - } - _, _, err := getOutputFieldsOrMatchCountRule(fl) - assert.Error(t, err) - }) - - t.Run("alias, not supported", func(t *testing.T) { - fl := []*planner.NodeSelectElement{ - planner.NewNodeSelectElement("", planner.WithFullColumnName( - planner.NewNodeFullColumnName("", "field", planner.FullColumnNameWithAlias("alias")))), - } - _, _, err := getOutputFieldsOrMatchCountRule(fl) - assert.Error(t, err) - }) - - t.Run("normal case", func(t *testing.T) { - fl := []*planner.NodeSelectElement{ - planner.NewNodeSelectElement("", planner.WithFullColumnName( - planner.NewNodeFullColumnName("", "field1"))), - planner.NewNodeSelectElement("", planner.WithFullColumnName( - planner.NewNodeFullColumnName("", "field2"))), - } - outputFields, match, err := getOutputFieldsOrMatchCountRule(fl) - assert.NoError(t, err) - assert.False(t, match) - assert.ElementsMatch(t, []string{"field1", "field2"}, outputFields) - }) -} - func Test_defaultExecutor_execCountWithFilter(t *testing.T) { t.Run("failed to query", func(t *testing.T) { s := mocks.NewProxyComponent(t) @@ -593,129 +538,67 @@ func Test_defaultExecutor_execQuery(t *testing.T) { }) } -func Test_wrapCountResult(t *testing.T) { - sqlRes := wrapCountResult(100, "count(*)") - assert.Equal(t, 1, len(sqlRes.Fields)) - assert.Equal(t, "count(*)", sqlRes.Fields[0].Name) - assert.Equal(t, querypb.Type_INT64, sqlRes.Fields[0].Type) - assert.Equal(t, 1, len(sqlRes.Rows)) - assert.Equal(t, 1, len(sqlRes.Rows[0])) - assert.Equal(t, querypb.Type_INT64, sqlRes.Rows[0][0].Type()) -} - -func Test_wrapQueryResults(t *testing.T) { - res := &milvuspb.QueryResults{ - Status: &commonpb.Status{}, - FieldsData: []*schemapb.FieldData{ - { - Type: schemapb.DataType_Int64, - FieldName: "field", - Field: &schemapb.FieldData_Scalars{ - Scalars: &schemapb.ScalarField{ - Data: &schemapb.ScalarField_LongData{ - LongData: &schemapb.LongArray{ - Data: []int64{1, 2, 3, 4}, - }, - }, +func Test_defaultExecutor_execANNS(t *testing.T) { + f1 := &schemapb.FieldData{ + Type: schemapb.DataType_Int64, + FieldName: "pk", + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{6, 5, 4, 3, 2, 1}, + }, + }, + }, + }, + } + f2 := &schemapb.FieldData{ + Type: schemapb.DataType_Float, + FieldName: "random", + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: []float32{6.6, 5.5, 4.4, 3.3, 2.2, 1.1}, }, }, }, }, - CollectionName: "test", } - sqlRes := wrapQueryResults(res) - assert.Equal(t, 1, len(sqlRes.Fields)) - assert.Equal(t, 4, len(sqlRes.Rows)) - assert.Equal(t, "field", sqlRes.Fields[0].Name) - assert.Equal(t, querypb.Type_INT64, sqlRes.Fields[0].Type) - assert.Equal(t, 1, len(sqlRes.Rows[0])) - assert.Equal(t, querypb.Type_INT64, sqlRes.Rows[0][0].Type()) -} -func Test_getSQLField(t *testing.T) { - f := &schemapb.FieldData{ - FieldName: "a", - Type: schemapb.DataType_Int64, + res := &milvuspb.SearchResults{ + Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, + Results: &schemapb.SearchResultData{ + NumQueries: 2, + TopK: 3, + FieldsData: []*schemapb.FieldData{f1, f2}, + Scores: []float32{1.1, 2.2, 3.3, 4.4, 5.5, 6.6}, + Topks: []int64{2, 2}, + }, + CollectionName: "hello_milvus", } - sf := getSQLField("t", f) - assert.Equal(t, "a", sf.Name) - assert.Equal(t, querypb.Type_INT64, sf.Type) - assert.Equal(t, "t", sf.Table) -} + s := mocks.NewProxyComponent(t) + s.On("Search", + mock.Anything, // context.Context + mock.Anything, // *milvuspb.SearchRequest + ).Return(res, nil) -func Test_toSQLType(t *testing.T) { - type args struct { - t schemapb.DataType - } - tests := []struct { - name string - args args - want querypb.Type - }{ - { - args: args{ - t: schemapb.DataType_Bool, - }, - want: querypb.Type_UINT8, - }, - { - args: args{ - t: schemapb.DataType_Int8, - }, - want: querypb.Type_INT8, - }, - { - args: args{ - t: schemapb.DataType_Int16, - }, - want: querypb.Type_INT16, - }, - { - args: args{ - t: schemapb.DataType_Int32, - }, - want: querypb.Type_INT32, - }, - { - args: args{ - t: schemapb.DataType_Int64, - }, - want: querypb.Type_INT64, - }, - { - args: args{ - t: schemapb.DataType_Float, - }, - want: querypb.Type_FLOAT32, - }, - { - args: args{ - t: schemapb.DataType_Double, - }, - want: querypb.Type_FLOAT64, - }, - { - args: args{ - t: schemapb.DataType_VarChar, - }, - want: querypb.Type_VARCHAR, - }, - { - args: args{ - t: schemapb.DataType_FloatVector, - }, - want: querypb.Type_NULL_TYPE, - }, - { - args: args{ - t: schemapb.DataType_BinaryVector, - }, - want: querypb.Type_NULL_TYPE, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - assert.Equalf(t, tt.want, toSQLType(tt.args.t), "toSQLType(%v)", tt.args.t) - }) - } + sql := ` +select +$query_number, pk, random, $distance +from hello_milvus +where random > 0.5 +anns by embeddings -> ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], [0.8, 0.7, 0.6, 0.6, 0.4, 0.3, 0.2, 0.1]) +params = (metric_type=L2, nprobe=10) +limit 3 +` + plan, warns, err := antlrparser.NewAntlrParser().Parse(sql) + assert.NoError(t, err) + assert.Nil(t, warns) + + e := NewDefaultExecutor(s).(*defaultExecutor) + _, err = e.execANNS(context.TODO(), + antlrparser.GetSqlStatements(plan.Node).Statements[0].DmlStatement.Unwrap().SelectStatement.Unwrap().SimpleSelect.Unwrap().Query.Unwrap(), + []string{"$query_number", "pk", "random", "$distance"}) + assert.NoError(t, err) } diff --git a/internal/mysqld/executor/utils.go b/internal/mysqld/executor/utils.go new file mode 100644 index 0000000000..1d6a9f6c40 --- /dev/null +++ b/internal/mysqld/executor/utils.go @@ -0,0 +1,363 @@ +package executor + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" + + "github.com/golang/protobuf/proto" + + "github.com/milvus-io/milvus-proto/go-api/schemapb" + "github.com/milvus-io/milvus/internal/util/typeutil" + querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" + "github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes" + + "github.com/milvus-io/milvus-proto/go-api/commonpb" + "github.com/milvus-io/milvus-proto/go-api/milvuspb" + "github.com/milvus-io/milvus/pkg/util/funcutil" + + "github.com/milvus-io/milvus/internal/mysqld/planner" + "github.com/milvus-io/milvus/pkg/common" +) + +func prepareSearchParams(limitClause *planner.NodeLimitClause, annsClause *planner.NodeANNSClause) map[string]string { + searchParams := map[string]string{ + common.AnnsFieldKey: annsClause.Column.Name, + common.TopKKey: fmt.Sprintf("%d", limitClause.Limit), + common.OffsetKey: fmt.Sprintf("%d", limitClause.Offset), + } + + if annsClause.Params.IsSome() { + for k, v := range annsClause.Params.Unwrap().KVs { + searchParams[k] = v + } + params, _ := json.Marshal(annsClause.Params.Unwrap().KVs) + searchParams[common.SearchParamsKey] = string(params) + } + + return searchParams +} + +func generateOutputsIndex(outputs []string) (index map[string]int, users []string) { + index = make(map[string]int, len(outputs)) + users = make([]string, 0, len(outputs)) + + fix := func(s string) string { + return strings.ToLower(strings.TrimSpace(s)) + } + + for i, f := range outputs { + fixed := fix(f) + if fixed != common.QueryNumberKey && fixed != common.DistanceKey { + users = append(users, f) + index[f] = i + } else { + index[fixed] = i + } + } + + return index, users +} + +func restoreExpr(from *planner.NodeFromClause) string { + filter := "" + if from.Where.IsSome() { + filter = planner.NewExprTextRestorer().RestoreExprText(from.Where.Unwrap()) + } + return filter +} + +func vector2PlaceholderGroupBytes(vectors []*planner.NodeVector) []byte { + phg := &commonpb.PlaceholderGroup{ + Placeholders: []*commonpb.PlaceholderValue{ + vector2Placeholder(vectors), + }, + } + + bs, _ := proto.Marshal(phg) + return bs +} + +func vector2Placeholder(vectors []*planner.NodeVector) *commonpb.PlaceholderValue { + var placeHolderType commonpb.PlaceholderType + + ph := &commonpb.PlaceholderValue{ + Tag: "$0", + Values: make([][]byte, 0, len(vectors)), + } + + if len(vectors) == 0 { + return ph + } + + if vectors[0].FloatVector.IsSome() { + placeHolderType = commonpb.PlaceholderType_FloatVector + } else { + placeHolderType = commonpb.PlaceholderType_BinaryVector + + } + + ph.Type = placeHolderType + for _, vector := range vectors { + ph.Values = append(ph.Values, vector.Serialize()) + } + + return ph +} + +func prepareSearchReq(tableName string, filter string, vectors []*planner.NodeVector, userOutputs []string, searchParams map[string]string) *milvuspb.SearchRequest { + phg := vector2PlaceholderGroupBytes(vectors) + req := &milvuspb.SearchRequest{ + CollectionName: tableName, + PartitionNames: nil, + Dsl: filter, + PlaceholderGroup: phg, + DslType: commonpb.DslType_BoolExprV1, + OutputFields: userOutputs, + SearchParams: funcutil.Map2KeyValuePair(searchParams), + TravelTimestamp: 0, + GuaranteeTimestamp: 2, // default bounded consistency level. + Nq: int64(len(vectors)), + } + return req +} + +func wrapFieldsData(collectionName string, fieldsData []*schemapb.FieldData) *sqltypes.Result { + nColumn := len(fieldsData) + fields := make([]*querypb.Field, 0, nColumn) + + if nColumn <= 0 { + return &sqltypes.Result{} + } + + for i := 0; i < nColumn; i++ { + fields = append(fields, getSQLField(collectionName, fieldsData[i])) + } + + nRow := typeutil.GetRowCount(fieldsData[0]) + rows := make([][]sqltypes.Value, 0, nRow) + for i := 0; i < nRow; i++ { + row := make([]sqltypes.Value, 0, nColumn) + for j := 0; j < nColumn; j++ { + row = append(row, getDataSingle(fieldsData[j], i)) + } + rows = append(rows, row) + } + + return &sqltypes.Result{ + Fields: fields, + Rows: rows, + } +} + +func wrapQueryResults(res *milvuspb.QueryResults) *sqltypes.Result { + return wrapFieldsData(res.GetCollectionName(), res.GetFieldsData()) +} + +func wrapCountResult(rowCnt int, column string) *sqltypes.Result { + result1 := &sqltypes.Result{ + Fields: []*querypb.Field{ + { + Name: column, + Type: querypb.Type_INT64, + }, + }, + Rows: [][]sqltypes.Value{ + { + sqltypes.NewInt64(int64(rowCnt)), + }, + }, + } + return result1 +} + +func getSQLField(tableName string, fieldData *schemapb.FieldData) *querypb.Field { + return &querypb.Field{ + Name: fieldData.GetFieldName(), + Type: toSQLType(fieldData.GetType()), + Table: tableName, + OrgTable: "", + Database: "", + OrgName: "", + ColumnLength: 0, + Charset: 0, + Decimals: 0, + Flags: 0, + } +} + +func toSQLType(t schemapb.DataType) querypb.Type { + switch t { + case schemapb.DataType_Bool: + // TODO: tinyint + return querypb.Type_UINT8 + case schemapb.DataType_Int8: + return querypb.Type_INT8 + case schemapb.DataType_Int16: + return querypb.Type_INT16 + case schemapb.DataType_Int32: + return querypb.Type_INT32 + case schemapb.DataType_Int64: + return querypb.Type_INT64 + case schemapb.DataType_Float: + return querypb.Type_FLOAT32 + case schemapb.DataType_Double: + return querypb.Type_FLOAT64 + case schemapb.DataType_VarChar: + return querypb.Type_VARCHAR + // TODO: vector. + default: + return querypb.Type_NULL_TYPE + } +} + +func getDataSingle(fieldData *schemapb.FieldData, idx int) sqltypes.Value { + switch fieldData.GetType() { + case schemapb.DataType_Bool: + // TODO: tinyint + return sqltypes.NewInt32(1) + case schemapb.DataType_Int8: + v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_IntData).IntData.GetData()[idx] + return sqltypes.MakeTrusted(sqltypes.Int8, strconv.AppendInt(nil, int64(v), 10)) + case schemapb.DataType_Int16: + v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_IntData).IntData.GetData()[idx] + return sqltypes.MakeTrusted(sqltypes.Int16, strconv.AppendInt(nil, int64(v), 10)) + case schemapb.DataType_Int32: + v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_IntData).IntData.GetData()[idx] + return sqltypes.MakeTrusted(sqltypes.Int32, strconv.AppendInt(nil, int64(v), 10)) + case schemapb.DataType_Int64: + v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_LongData).LongData.GetData()[idx] + return sqltypes.MakeTrusted(sqltypes.Int64, strconv.AppendInt(nil, v, 10)) + case schemapb.DataType_Float: + v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_FloatData).FloatData.GetData()[idx] + return sqltypes.MakeTrusted(sqltypes.Float32, strconv.AppendFloat(nil, float64(v), 'f', -1, 64)) + case schemapb.DataType_Double: + v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_DoubleData).DoubleData.GetData()[idx] + return sqltypes.MakeTrusted(sqltypes.Float64, strconv.AppendFloat(nil, v, 'g', -1, 64)) + case schemapb.DataType_VarChar: + v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_StringData).StringData.GetData()[idx] + return sqltypes.NewVarChar(v) + + // TODO: vector. + default: + // TODO: should raise error here. + return sqltypes.NewInt32(1) + } +} + +func wrapSearchResult(res *milvuspb.SearchResults, outputsIndex map[string]int, userOutputs []string) *sqltypes.Result { + qnIndex, qnOk := outputsIndex[common.QueryNumberKey] + disIndex, disOk := outputsIndex[common.DistanceKey] + + fieldsData := make([]*schemapb.FieldData, len(res.GetResults().GetFieldsData())) + copy(fieldsData, res.GetResults().GetFieldsData()) + + insert := func(fields []*schemapb.FieldData, field *schemapb.FieldData, index int) []*schemapb.FieldData { + if index >= len(fields) { + return append(fields, field) + } + return append(fieldsData[:index], append([]*schemapb.FieldData{field}, fieldsData[index:]...)...) + } + + if qnOk && disOk { + qnField := generateQueryNumberFieldData(res) + disField := generateDistanceFieldsData(res) + + index1, index2 := qnIndex, disIndex + f1, f2 := qnField, disField + if qnIndex > disIndex { + index1, index2 = disIndex, qnIndex + f1, f2 = disField, qnField + } + + fieldsData = insert(fieldsData, f1, index1) + fieldsData = insert(fieldsData, f2, index2) + } else if qnOk { + qnField := generateQueryNumberFieldData(res) + fieldsData = insert(fieldsData, qnField, qnIndex) + } else if disOk { + disField := generateDistanceFieldsData(res) + fieldsData = insert(fieldsData, disField, disIndex) + } + + return wrapFieldsData(res.GetCollectionName(), fieldsData) +} + +func generateQueryNumberFieldData(res *milvuspb.SearchResults) *schemapb.FieldData { + arr := make([]int64, 0, res.GetResults().GetNumQueries()*res.GetResults().GetTopK()) + + for i, topk := range res.GetResults().GetTopks() { + for j := 0; int64(j) < topk; j++ { + arr = append(arr, int64(i)) + } + } + + return &schemapb.FieldData{ + Type: schemapb.DataType_Int64, + FieldName: common.QueryNumberKey, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: arr, + }, + }, + }, + }, + } +} + +func generateDistanceFieldsData(res *milvuspb.SearchResults) *schemapb.FieldData { + return &schemapb.FieldData{ + Type: schemapb.DataType_Float, + FieldName: common.DistanceKey, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: res.GetResults().GetScores(), + }, + }, + }, + }, + } +} + +func getOutputFieldsOrMatchCountRule(fields []*planner.NodeSelectElement) (outputFields []string, match bool, err error) { + match = false + l := len(fields) + + if l == 1 { + entry := fields[0] + match = entry.FunctionCall.IsSome() && + entry.FunctionCall.Unwrap().Agg.IsSome() && + entry.FunctionCall.Unwrap().Agg.Unwrap().AggCount.IsSome() + } + + if match { + return nil, match, nil + } + + outputFields = make([]string, 0, l) + for _, entry := range fields { + if entry.Star.IsSome() { + // TODO: support `select *`. + return nil, match, fmt.Errorf("* is not supported") + } + + if entry.FunctionCall.IsSome() { + return nil, match, fmt.Errorf("combined select elements is not supported") + } + + if entry.FullColumnName.IsSome() { + c := entry.FullColumnName.Unwrap() + if c.Alias.IsSome() { + return nil, match, fmt.Errorf("alias for select elements is not supported") + } + outputFields = append(outputFields, c.Name) + } + } + + return outputFields, false, nil +} diff --git a/internal/mysqld/executor/utils_test.go b/internal/mysqld/executor/utils_test.go index aae3d863af..8385a4597b 100644 --- a/internal/mysqld/executor/utils_test.go +++ b/internal/mysqld/executor/utils_test.go @@ -1,6 +1,15 @@ package executor -import "github.com/milvus-io/milvus/internal/mysqld/planner" +import ( + "testing" + + "github.com/milvus-io/milvus-proto/go-api/commonpb" + "github.com/milvus-io/milvus-proto/go-api/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/schemapb" + "github.com/milvus-io/milvus/internal/mysqld/planner" + "github.com/stretchr/testify/assert" + "github.com/xelabs/go-mysqlstack/sqlparser/depends/query" +) func GenNodeExpression(field string, c int64, op planner.ComparisonOperator) *planner.NodeExpression { n := planner.NewNodeExpression("", planner.WithPredicate( @@ -20,3 +29,187 @@ func GenNodeExpression(field string, c int64, op planner.ComparisonOperator) *pl ) return n } + +func Test_getOutputFieldsOrMatchCountRule(t *testing.T) { + t.Run("match count rule", func(t *testing.T) { + fl := []*planner.NodeSelectElement{ + planner.NewNodeSelectElement("", planner.WithFunctionCall( + planner.NewNodeFunctionCall("", planner.WithAgg( + planner.NewNodeAggregateWindowedFunction("", planner.WithAggCount( + planner.NewNodeCount(""))))))), + } + _, match, err := getOutputFieldsOrMatchCountRule(fl) + assert.NoError(t, err) + assert.True(t, match) + }) + + t.Run("star *, not supported", func(t *testing.T) { + fl := []*planner.NodeSelectElement{ + planner.NewNodeSelectElement("", planner.WithStar()), + } + _, _, err := getOutputFieldsOrMatchCountRule(fl) + assert.Error(t, err) + }) + + t.Run("combined", func(t *testing.T) { + fl := []*planner.NodeSelectElement{ + planner.NewNodeSelectElement("", planner.WithFunctionCall( + planner.NewNodeFunctionCall("", planner.WithAgg( + planner.NewNodeAggregateWindowedFunction("", planner.WithAggCount( + planner.NewNodeCount(""))))))), + planner.NewNodeSelectElement("", planner.WithFullColumnName( + planner.NewNodeFullColumnName("", "field"))), + } + _, _, err := getOutputFieldsOrMatchCountRule(fl) + assert.Error(t, err) + }) + + t.Run("alias, not supported", func(t *testing.T) { + fl := []*planner.NodeSelectElement{ + planner.NewNodeSelectElement("", planner.WithFullColumnName( + planner.NewNodeFullColumnName("", "field", planner.FullColumnNameWithAlias("alias")))), + } + _, _, err := getOutputFieldsOrMatchCountRule(fl) + assert.Error(t, err) + }) + + t.Run("normal case", func(t *testing.T) { + fl := []*planner.NodeSelectElement{ + planner.NewNodeSelectElement("", planner.WithFullColumnName( + planner.NewNodeFullColumnName("", "field1"))), + planner.NewNodeSelectElement("", planner.WithFullColumnName( + planner.NewNodeFullColumnName("", "field2"))), + } + outputFields, match, err := getOutputFieldsOrMatchCountRule(fl) + assert.NoError(t, err) + assert.False(t, match) + assert.ElementsMatch(t, []string{"field1", "field2"}, outputFields) + }) +} + +func Test_wrapCountResult(t *testing.T) { + sqlRes := wrapCountResult(100, "count(*)") + assert.Equal(t, 1, len(sqlRes.Fields)) + assert.Equal(t, "count(*)", sqlRes.Fields[0].Name) + assert.Equal(t, query.Type_INT64, sqlRes.Fields[0].Type) + assert.Equal(t, 1, len(sqlRes.Rows)) + assert.Equal(t, 1, len(sqlRes.Rows[0])) + assert.Equal(t, query.Type_INT64, sqlRes.Rows[0][0].Type()) +} + +func Test_wrapQueryResults(t *testing.T) { + res := &milvuspb.QueryResults{ + Status: &commonpb.Status{}, + FieldsData: []*schemapb.FieldData{ + { + Type: schemapb.DataType_Int64, + FieldName: "field", + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4}, + }, + }, + }, + }, + }, + }, + CollectionName: "test", + } + sqlRes := wrapQueryResults(res) + assert.Equal(t, 1, len(sqlRes.Fields)) + assert.Equal(t, 4, len(sqlRes.Rows)) + assert.Equal(t, "field", sqlRes.Fields[0].Name) + assert.Equal(t, query.Type_INT64, sqlRes.Fields[0].Type) + assert.Equal(t, 1, len(sqlRes.Rows[0])) + assert.Equal(t, query.Type_INT64, sqlRes.Rows[0][0].Type()) +} + +func Test_getSQLField(t *testing.T) { + f := &schemapb.FieldData{ + FieldName: "a", + Type: schemapb.DataType_Int64, + } + sf := getSQLField("t", f) + assert.Equal(t, "a", sf.Name) + assert.Equal(t, query.Type_INT64, sf.Type) + assert.Equal(t, "t", sf.Table) +} + +func Test_toSQLType(t *testing.T) { + type args struct { + t schemapb.DataType + } + tests := []struct { + name string + args args + want query.Type + }{ + { + args: args{ + t: schemapb.DataType_Bool, + }, + want: query.Type_UINT8, + }, + { + args: args{ + t: schemapb.DataType_Int8, + }, + want: query.Type_INT8, + }, + { + args: args{ + t: schemapb.DataType_Int16, + }, + want: query.Type_INT16, + }, + { + args: args{ + t: schemapb.DataType_Int32, + }, + want: query.Type_INT32, + }, + { + args: args{ + t: schemapb.DataType_Int64, + }, + want: query.Type_INT64, + }, + { + args: args{ + t: schemapb.DataType_Float, + }, + want: query.Type_FLOAT32, + }, + { + args: args{ + t: schemapb.DataType_Double, + }, + want: query.Type_FLOAT64, + }, + { + args: args{ + t: schemapb.DataType_VarChar, + }, + want: query.Type_VARCHAR, + }, + { + args: args{ + t: schemapb.DataType_FloatVector, + }, + want: query.Type_NULL_TYPE, + }, + { + args: args{ + t: schemapb.DataType_BinaryVector, + }, + want: query.Type_NULL_TYPE, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, toSQLType(tt.args.t), "toSQLType(%v)", tt.args.t) + }) + } +} diff --git a/internal/mysqld/planner/float_vector.go b/internal/mysqld/planner/float_vector.go index a9c7a81e24..e656fda34c 100644 --- a/internal/mysqld/planner/float_vector.go +++ b/internal/mysqld/planner/float_vector.go @@ -1,9 +1,25 @@ package planner +import ( + "math" + + "github.com/milvus-io/milvus/pkg/common" +) + type NodeFloatVector struct { Array []float32 } +func (n NodeFloatVector) Serialize() []byte { + data := make([]byte, 0, 4*len(n.Array)) // float32 occupies 4 bytes + buf := make([]byte, 4) + for _, f := range n.Array { + common.Endian.PutUint32(buf, math.Float32bits(f)) + data = append(data, buf...) + } + return data +} + func NewNodeFloatVector(arr []float32) *NodeFloatVector { return &NodeFloatVector{ Array: arr, diff --git a/internal/mysqld/planner/vector.go b/internal/mysqld/planner/vector.go index 5fcf3c7cd7..46e68d2b3c 100644 --- a/internal/mysqld/planner/vector.go +++ b/internal/mysqld/planner/vector.go @@ -14,6 +14,13 @@ func (n *NodeVector) apply(opts ...NodeVectorOption) { } } +func (n *NodeVector) Serialize() []byte { + if n.FloatVector.IsSome() { + return n.FloatVector.Unwrap().Serialize() + } + return nil +} + func WithFloatVector(v *NodeFloatVector) NodeVectorOption { return func(n *NodeVector) { n.FloatVector = optional.Some(v) diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 897c928011..3d9c6af86a 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -42,16 +42,6 @@ import ( ) const ( - IgnoreGrowingKey = "ignore_growing" - AnnsFieldKey = "anns_field" - TopKKey = "topk" - NQKey = "nq" - MetricTypeKey = "metric_type" - SearchParamsKey = "params" - RoundDecimalKey = "round_decimal" - OffsetKey = "offset" - LimitKey = "limit" - InsertTaskName = "InsertTask" CreateCollectionTaskName = "CreateCollectionTask" DropCollectionTaskName = "DropCollectionTask" diff --git a/internal/proxy/task_index_test.go b/internal/proxy/task_index_test.go index 0fc0ff608d..0237100317 100644 --- a/internal/proxy/task_index_test.go +++ b/internal/proxy/task_index_test.go @@ -21,6 +21,8 @@ import ( "os" "testing" + "github.com/milvus-io/milvus/pkg/common" + "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -287,7 +289,7 @@ func Test_parseIndexParams(t *testing.T) { Value: "HNSW", }, { - Key: MetricTypeKey, + Key: common.MetricTypeKey, Value: "IP", }, { @@ -321,7 +323,7 @@ func Test_parseIndexParams(t *testing.T) { Value: "128", }, { - Key: MetricTypeKey, + Key: common.MetricTypeKey, Value: "L2", }, }}, @@ -338,7 +340,7 @@ func Test_parseIndexParams(t *testing.T) { Value: "HNSW", }, { - Key: MetricTypeKey, + Key: common.MetricTypeKey, Value: "IP", }, { diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 1d0831a324..c706b3c422 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -134,32 +134,32 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e err error ) - limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, queryParamsPair) + limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(common.LimitKey, queryParamsPair) // if limit is not provided if err != nil { return &queryParams{limit: typeutil.Unlimited}, nil } limit, err = strconv.ParseInt(limitStr, 0, 64) if err != nil { - return nil, fmt.Errorf("%s [%s] is invalid", LimitKey, limitStr) + return nil, fmt.Errorf("%s [%s] is invalid", common.LimitKey, limitStr) } if limit != 0 { if err := validateLimit(limit); err != nil { - return nil, fmt.Errorf("%s [%d] is invalid, %w", LimitKey, limit, err) + return nil, fmt.Errorf("%s [%d] is invalid, %w", common.LimitKey, limit, err) } } - offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, queryParamsPair) + offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(common.OffsetKey, queryParamsPair) // if offset is provided if err == nil { offset, err = strconv.ParseInt(offsetStr, 0, 64) if err != nil { - return nil, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr) + return nil, fmt.Errorf("%s [%s] is invalid", common.OffsetKey, offsetStr) } if offset != 0 { if err := validateLimit(offset); err != nil { - return nil, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err) + return nil, fmt.Errorf("%s [%d] is invalid, %w", common.OffsetKey, offset, err) } } } @@ -289,7 +289,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error { //fetch search_growing from search param var ignoreGrowing bool for i, kv := range t.request.GetQueryParams() { - if kv.GetKey() == IgnoreGrowingKey { + if kv.GetKey() == common.IgnoreGrowingKey { ignoreGrowing, err = strconv.ParseBool(kv.Value) if err != nil { return errors.New("parse search growing failed") diff --git a/internal/proxy/task_query_test.go b/internal/proxy/task_query_test.go index b52cd58a46..493e7e4295 100644 --- a/internal/proxy/task_query_test.go +++ b/internal/proxy/task_query_test.go @@ -154,7 +154,7 @@ func TestQueryTask_all(t *testing.T) { Expr: expr, QueryParams: []*commonpb.KeyValuePair{ { - Key: IgnoreGrowingKey, + Key: common.IgnoreGrowingKey, Value: "false", }, }, @@ -392,16 +392,16 @@ func TestTaskQuery_functions(t *testing.T) { outOffset int64 }{ {"empty input", []string{}, []string{}, false, typeutil.Unlimited, 0}, - {"valid limit=1", []string{LimitKey}, []string{"1"}, false, 1, 0}, - {"valid limit=1, offset=2", []string{LimitKey, OffsetKey}, []string{"1", "2"}, false, 1, 2}, - {"valid no limit, offset=2", []string{OffsetKey}, []string{"2"}, false, typeutil.Unlimited, 0}, - {"invalid limit str", []string{LimitKey}, []string{"a"}, true, 0, 0}, - {"invalid limit zero", []string{LimitKey}, []string{"0"}, true, 0, 0}, - {"invalid limit negative", []string{LimitKey}, []string{"-1"}, true, 0, 0}, - {"invalid limit 16385", []string{LimitKey}, []string{"16385"}, true, 0, 0}, - {"invalid offset negative", []string{LimitKey, OffsetKey}, []string{"1", "-1"}, true, 0, 0}, - {"invalid offset 16385", []string{LimitKey, OffsetKey}, []string{"1", "16385"}, true, 0, 0}, - {"invalid limit=16384 offset=16384", []string{LimitKey, OffsetKey}, []string{"16384", "16384"}, true, 0, 0}, + {"valid limit=1", []string{common.LimitKey}, []string{"1"}, false, 1, 0}, + {"valid limit=1, offset=2", []string{common.LimitKey, common.OffsetKey}, []string{"1", "2"}, false, 1, 2}, + {"valid no limit, offset=2", []string{common.OffsetKey}, []string{"2"}, false, typeutil.Unlimited, 0}, + {"invalid limit str", []string{common.LimitKey}, []string{"a"}, true, 0, 0}, + {"invalid limit zero", []string{common.LimitKey}, []string{"0"}, true, 0, 0}, + {"invalid limit negative", []string{common.LimitKey}, []string{"-1"}, true, 0, 0}, + {"invalid limit 16385", []string{common.LimitKey}, []string{"16385"}, true, 0, 0}, + {"invalid offset negative", []string{common.LimitKey, common.OffsetKey}, []string{"1", "-1"}, true, 0, 0}, + {"invalid offset 16385", []string{common.LimitKey, common.OffsetKey}, []string{"1", "16385"}, true, 0, 0}, + {"invalid limit=16384 offset=16384", []string{common.LimitKey, common.OffsetKey}, []string{"16384", "16384"}, true, 0, 0}, } for _, test := range tests { diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 63bb20d8f1..9bfab21448 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -96,36 +96,36 @@ func getPartitionIDs(ctx context.Context, collectionName string, partitionNames // parseSearchInfo returns QueryInfo and offset func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryInfo, int64, error) { - topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair) + topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(common.TopKKey, searchParamsPair) if err != nil { - return nil, 0, errors.New(TopKKey + " not found in search_params") + return nil, 0, errors.New(common.TopKKey + " not found in search_params") } topK, err := strconv.ParseInt(topKStr, 0, 64) if err != nil { - return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr) + return nil, 0, fmt.Errorf("%s [%s] is invalid", common.TopKKey, topKStr) } if err := validateLimit(topK); err != nil { - return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err) + return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", common.TopKKey, topK, err) } var offset int64 - offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, searchParamsPair) + offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(common.OffsetKey, searchParamsPair) if err == nil { offset, err = strconv.ParseInt(offsetStr, 0, 64) if err != nil { - return nil, 0, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr) + return nil, 0, fmt.Errorf("%s [%s] is invalid", common.OffsetKey, offsetStr) } if offset != 0 { if err := validateLimit(offset); err != nil { - return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err) + return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", common.OffsetKey, offset, err) } } } queryTopK := topK + offset if err := validateLimit(queryTopK); err != nil { - return nil, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err) + return nil, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", common.OffsetKey, common.TopKKey, queryTopK, err) } metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.MetricTypeKey, searchParamsPair) @@ -133,20 +133,20 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryIn return nil, 0, errors.New(common.MetricTypeKey + " not found in search_params") } - roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, searchParamsPair) + roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(common.RoundDecimalKey, searchParamsPair) if err != nil { roundDecimalStr = "-1" } roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64) if err != nil { - return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr) + return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", common.RoundDecimalKey, roundDecimalStr) } if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) { - return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr) + return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", common.RoundDecimalKey, roundDecimalStr) } - searchParamStr, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, searchParamsPair) + searchParamStr, err := funcutil.GetAttrByKeyFromRepeatedKV(common.SearchParamsKey, searchParamsPair) if err != nil { return nil, 0, err } @@ -244,7 +244,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error { //fetch search_growing from search param var ignoreGrowing bool for i, kv := range t.request.GetSearchParams() { - if kv.GetKey() == IgnoreGrowingKey { + if kv.GetKey() == common.IgnoreGrowingKey { ignoreGrowing, err = strconv.ParseBool(kv.GetValue()) if err != nil { return errors.New("parse search growing failed") @@ -256,9 +256,9 @@ func (t *searchTask) PreExecute(ctx context.Context) error { t.SearchRequest.IgnoreGrowing = ignoreGrowing if t.request.GetDslType() == commonpb.DslType_BoolExprV1 { - annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams()) + annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(common.AnnsFieldKey, t.request.GetSearchParams()) if err != nil { - return errors.New(AnnsFieldKey + " not found in search_params") + return errors.New(common.AnnsFieldKey + " not found in search_params") } queryInfo, offset, err := parseSearchInfo(t.request.GetSearchParams()) @@ -327,7 +327,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error { // Check if nq is valid: // https://milvus.io/docs/limitations.md if err := validateLimit(nq); err != nil { - return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err) + return fmt.Errorf("%s [%d] is invalid, %w", common.NQKey, nq, err) } t.SearchRequest.Nq = nq diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 69cc78ff83..8164f8d171 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -88,11 +88,11 @@ func createColl(t *testing.T, name string, rc types.RootCoord) { func getValidSearchParams() []*commonpb.KeyValuePair { return []*commonpb.KeyValuePair{ { - Key: AnnsFieldKey, + Key: common.AnnsFieldKey, Value: testFloatVecField, }, { - Key: TopKKey, + Key: common.TopKKey, Value: "10", }, { @@ -100,15 +100,15 @@ func getValidSearchParams() []*commonpb.KeyValuePair { Value: distance.L2, }, { - Key: SearchParamsKey, + Key: common.SearchParamsKey, Value: `{"nprobe": 10}`, }, { - Key: RoundDecimalKey, + Key: common.RoundDecimalKey, Value: "-1", }, { - Key: IgnoreGrowingKey, + Key: common.IgnoreGrowingKey, Value: "false", }} } @@ -230,7 +230,7 @@ func TestSearchTask_PreExecute(t *testing.T) { createColl(t, collName, rc) task := getSearchTask(t, collName) - task.request.SearchParams = getInvalidSearchParams(IgnoreGrowingKey) + task.request.SearchParams = getInvalidSearchParams(common.IgnoreGrowingKey) err = task.PreExecute(ctx) assert.Error(t, err) }) @@ -1852,7 +1852,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) { sp := getValidSearchParams() sp = append(sp, &commonpb.KeyValuePair{ - Key: OffsetKey, + Key: common.OffsetKey, Value: strconv.FormatInt(targetOffset, 10), }) @@ -1864,26 +1864,26 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) { t.Run("parseSearchInfo error", func(t *testing.T) { spNoTopk := []*commonpb.KeyValuePair{{ - Key: AnnsFieldKey, + Key: common.AnnsFieldKey, Value: testFloatVecField}} spInvalidTopk := append(spNoTopk, &commonpb.KeyValuePair{ - Key: TopKKey, + Key: common.TopKKey, Value: "invalid", }) spInvalidTopk65536 := append(spNoTopk, &commonpb.KeyValuePair{ - Key: TopKKey, + Key: common.TopKKey, Value: "65536", }) spNoMetricType := append(spNoTopk, &commonpb.KeyValuePair{ - Key: TopKKey, + Key: common.TopKKey, Value: "10", }) spInvalidTopkPlusOffset := append(spNoTopk, &commonpb.KeyValuePair{ - Key: OffsetKey, + Key: common.OffsetKey, Value: "65535", }) @@ -1894,32 +1894,32 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) { // no roundDecimal is valid noRoundDecimal := append(spNoSearchParams, &commonpb.KeyValuePair{ - Key: SearchParamsKey, + Key: common.SearchParamsKey, Value: `{"nprobe": 10}`, }) spInvalidRoundDecimal2 := append(noRoundDecimal, &commonpb.KeyValuePair{ - Key: RoundDecimalKey, + Key: common.RoundDecimalKey, Value: "1000", }) spInvalidRoundDecimal := append(noRoundDecimal, &commonpb.KeyValuePair{ - Key: RoundDecimalKey, + Key: common.RoundDecimalKey, Value: "invalid", }) spInvalidOffsetNoInt := append(noRoundDecimal, &commonpb.KeyValuePair{ - Key: OffsetKey, + Key: common.OffsetKey, Value: "invalid", }) spInvalidOffsetNegative := append(noRoundDecimal, &commonpb.KeyValuePair{ - Key: OffsetKey, + Key: common.OffsetKey, Value: "-1", }) spInvalidOffsetTooLarge := append(noRoundDecimal, &commonpb.KeyValuePair{ - Key: OffsetKey, + Key: common.OffsetKey, Value: "16386", }) diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index f331915864..3d9db2f316 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -329,19 +329,19 @@ func constructSearchRequest( Value: distance.L2, }, { - Key: SearchParamsKey, + Key: common.SearchParamsKey, Value: string(b), }, { - Key: AnnsFieldKey, + Key: common.AnnsFieldKey, Value: floatVecField, }, { - Key: TopKKey, + Key: common.TopKKey, Value: strconv.Itoa(topk), }, { - Key: RoundDecimalKey, + Key: common.RoundDecimalKey, Value: strconv.Itoa(roundDecimal), }, }, diff --git a/pkg/common/common.go b/pkg/common/common.go index faef287a92..c88b396af3 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -81,10 +81,21 @@ const ( SearchParamKey = "search_param" SegmentNumKey = "segment_num" + IgnoreGrowingKey = "ignore_growing" + AnnsFieldKey = "anns_field" + NQKey = "nq" + SearchParamsKey = "params" + RoundDecimalKey = "round_decimal" + OffsetKey = "offset" + LimitKey = "limit" + IndexParamsKey = "params" IndexTypeKey = "index_type" MetricTypeKey = "metric_type" DimKey = "dim" + + QueryNumberKey = "$query_number" + DistanceKey = "$distance" ) // Collection properties key