mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
Support ANNS statement execution (#23506)
Signed-off-by: longjiquan <jiquan.long@zilliz.com>
This commit is contained in:
parent
c013492762
commit
c46aa4c3d4
@ -5,18 +5,17 @@ import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mysqld/parser/antlrparser"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mysqld/planner"
|
||||
"github.com/milvus-io/milvus/internal/types"
|
||||
@ -75,11 +74,6 @@ func (e *defaultExecutor) execSelect(ctx context.Context, n *planner.NodeSelectS
|
||||
|
||||
q := stmt.Query.Unwrap()
|
||||
|
||||
if q.Limit.IsSome() {
|
||||
// TODO: use pagination.
|
||||
return nil, fmt.Errorf("invalid query statement, limit/offset is not supported")
|
||||
}
|
||||
|
||||
if len(q.SelectSpecs) != 0 {
|
||||
return nil, fmt.Errorf("invalid query statement, select spec is not supported")
|
||||
}
|
||||
@ -117,6 +111,16 @@ func (e *defaultExecutor) execSelect(ctx context.Context, n *planner.NodeSelectS
|
||||
|
||||
// `match` is false.
|
||||
|
||||
if q.Anns.IsSome() {
|
||||
// reuse the parsed `outputFields`.
|
||||
return e.execANNS(ctx, q, outputFields)
|
||||
}
|
||||
|
||||
if q.Limit.IsSome() {
|
||||
// TODO: use pagination.
|
||||
return nil, fmt.Errorf("invalid query statement, limit/offset is not supported")
|
||||
}
|
||||
|
||||
if !from.Where.IsSome() { // query without filter.
|
||||
return nil, fmt.Errorf("query without filter is not supported")
|
||||
}
|
||||
@ -129,42 +133,33 @@ func (e *defaultExecutor) execSelect(ctx context.Context, n *planner.NodeSelectS
|
||||
return wrapQueryResults(res), nil
|
||||
}
|
||||
|
||||
func getOutputFieldsOrMatchCountRule(fields []*planner.NodeSelectElement) (outputFields []string, match bool, err error) {
|
||||
match = false
|
||||
l := len(fields)
|
||||
|
||||
if l == 1 {
|
||||
entry := fields[0]
|
||||
match = entry.FunctionCall.IsSome() &&
|
||||
entry.FunctionCall.Unwrap().Agg.IsSome() &&
|
||||
entry.FunctionCall.Unwrap().Agg.Unwrap().AggCount.IsSome()
|
||||
func (e *defaultExecutor) execANNS(ctx context.Context, q *planner.NodeQuerySpecification, outputs []string) (*sqltypes.Result, error) {
|
||||
if !q.Limit.IsSome() {
|
||||
return nil, fmt.Errorf("limit not specified in the ANNS statement")
|
||||
}
|
||||
|
||||
if match {
|
||||
return nil, match, nil
|
||||
annsClause := q.Anns.Unwrap()
|
||||
|
||||
searchParams := prepareSearchParams(q.Limit.Unwrap(), annsClause)
|
||||
|
||||
outputsIndex, userOutputs := generateOutputsIndex(outputs)
|
||||
|
||||
filter := restoreExpr(q.From.Unwrap())
|
||||
|
||||
tableName := q.From.Unwrap().TableSources[0].TableName.Unwrap()
|
||||
|
||||
req := prepareSearchReq(tableName, filter, annsClause.Vectors, userOutputs, searchParams)
|
||||
|
||||
res, err := e.s.Search(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputFields = make([]string, 0, l)
|
||||
for _, entry := range fields {
|
||||
if entry.Star.IsSome() {
|
||||
// TODO: support `select *`.
|
||||
return nil, match, fmt.Errorf("* is not supported")
|
||||
}
|
||||
|
||||
if entry.FunctionCall.IsSome() {
|
||||
return nil, match, fmt.Errorf("combined select elements is not supported")
|
||||
}
|
||||
|
||||
if entry.FullColumnName.IsSome() {
|
||||
c := entry.FullColumnName.Unwrap()
|
||||
if c.Alias.IsSome() {
|
||||
return nil, match, fmt.Errorf("alias for select elements is not supported")
|
||||
}
|
||||
outputFields = append(outputFields, c.Name)
|
||||
}
|
||||
if res.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
||||
return nil, common.NewStatusError(res.GetStatus().GetErrorCode(), res.GetStatus().GetReason())
|
||||
}
|
||||
|
||||
return outputFields, false, nil
|
||||
return wrapSearchResult(res, outputsIndex, userOutputs), nil
|
||||
}
|
||||
|
||||
func (e *defaultExecutor) execCountWithFilter(ctx context.Context, tableName string, filter string) (*sqltypes.Result, error) {
|
||||
@ -229,126 +224,6 @@ func (e *defaultExecutor) execQuery(ctx context.Context, tableName string, filte
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func wrapCountResult(rowCnt int, column string) *sqltypes.Result {
|
||||
result1 := &sqltypes.Result{
|
||||
Fields: []*querypb.Field{
|
||||
{
|
||||
Name: column,
|
||||
Type: querypb.Type_INT64,
|
||||
},
|
||||
},
|
||||
Rows: [][]sqltypes.Value{
|
||||
{
|
||||
sqltypes.NewInt64(int64(rowCnt)),
|
||||
},
|
||||
},
|
||||
}
|
||||
return result1
|
||||
}
|
||||
|
||||
func wrapQueryResults(res *milvuspb.QueryResults) *sqltypes.Result {
|
||||
fieldsData := res.GetFieldsData()
|
||||
nColumn := len(fieldsData)
|
||||
fields := make([]*querypb.Field, 0, nColumn)
|
||||
|
||||
if nColumn <= 0 {
|
||||
return &sqltypes.Result{}
|
||||
}
|
||||
|
||||
for i := 0; i < nColumn; i++ {
|
||||
fields = append(fields, getSQLField(res.GetCollectionName(), fieldsData[i]))
|
||||
}
|
||||
|
||||
nRow := typeutil.GetRowCount(fieldsData[0])
|
||||
rows := make([][]sqltypes.Value, 0, nRow)
|
||||
for i := 0; i < nRow; i++ {
|
||||
row := make([]sqltypes.Value, 0, nColumn)
|
||||
for j := 0; j < nColumn; j++ {
|
||||
row = append(row, getDataSingle(fieldsData[j], i))
|
||||
}
|
||||
rows = append(rows, row)
|
||||
}
|
||||
|
||||
return &sqltypes.Result{
|
||||
Fields: fields,
|
||||
Rows: rows,
|
||||
}
|
||||
}
|
||||
|
||||
func getSQLField(tableName string, fieldData *schemapb.FieldData) *querypb.Field {
|
||||
return &querypb.Field{
|
||||
Name: fieldData.GetFieldName(),
|
||||
Type: toSQLType(fieldData.GetType()),
|
||||
Table: tableName,
|
||||
OrgTable: "",
|
||||
Database: "",
|
||||
OrgName: "",
|
||||
ColumnLength: 0,
|
||||
Charset: 0,
|
||||
Decimals: 0,
|
||||
Flags: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func toSQLType(t schemapb.DataType) querypb.Type {
|
||||
switch t {
|
||||
case schemapb.DataType_Bool:
|
||||
// TODO: tinyint
|
||||
return querypb.Type_UINT8
|
||||
case schemapb.DataType_Int8:
|
||||
return querypb.Type_INT8
|
||||
case schemapb.DataType_Int16:
|
||||
return querypb.Type_INT16
|
||||
case schemapb.DataType_Int32:
|
||||
return querypb.Type_INT32
|
||||
case schemapb.DataType_Int64:
|
||||
return querypb.Type_INT64
|
||||
case schemapb.DataType_Float:
|
||||
return querypb.Type_FLOAT32
|
||||
case schemapb.DataType_Double:
|
||||
return querypb.Type_FLOAT64
|
||||
case schemapb.DataType_VarChar:
|
||||
return querypb.Type_VARCHAR
|
||||
// TODO: vector.
|
||||
default:
|
||||
return querypb.Type_NULL_TYPE
|
||||
}
|
||||
}
|
||||
|
||||
func getDataSingle(fieldData *schemapb.FieldData, idx int) sqltypes.Value {
|
||||
switch fieldData.GetType() {
|
||||
case schemapb.DataType_Bool:
|
||||
// TODO: tinyint
|
||||
return sqltypes.NewInt32(1)
|
||||
case schemapb.DataType_Int8:
|
||||
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_IntData).IntData.GetData()[idx]
|
||||
return sqltypes.MakeTrusted(sqltypes.Int8, strconv.AppendInt(nil, int64(v), 10))
|
||||
case schemapb.DataType_Int16:
|
||||
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_IntData).IntData.GetData()[idx]
|
||||
return sqltypes.MakeTrusted(sqltypes.Int16, strconv.AppendInt(nil, int64(v), 10))
|
||||
case schemapb.DataType_Int32:
|
||||
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_IntData).IntData.GetData()[idx]
|
||||
return sqltypes.MakeTrusted(sqltypes.Int32, strconv.AppendInt(nil, int64(v), 10))
|
||||
case schemapb.DataType_Int64:
|
||||
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_LongData).LongData.GetData()[idx]
|
||||
return sqltypes.MakeTrusted(sqltypes.Int64, strconv.AppendInt(nil, v, 10))
|
||||
case schemapb.DataType_Float:
|
||||
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_FloatData).FloatData.GetData()[idx]
|
||||
return sqltypes.MakeTrusted(sqltypes.Float32, strconv.AppendFloat(nil, float64(v), 'f', -1, 64))
|
||||
case schemapb.DataType_Double:
|
||||
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_DoubleData).DoubleData.GetData()[idx]
|
||||
return sqltypes.MakeTrusted(sqltypes.Float64, strconv.AppendFloat(nil, v, 'g', -1, 64))
|
||||
case schemapb.DataType_VarChar:
|
||||
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_StringData).StringData.GetData()[idx]
|
||||
return sqltypes.NewVarChar(v)
|
||||
|
||||
// TODO: vector.
|
||||
default:
|
||||
// TODO: should raise error here.
|
||||
return sqltypes.NewInt32(1)
|
||||
}
|
||||
}
|
||||
|
||||
func NewDefaultExecutor(s types.ProxyComponent) Executor {
|
||||
return &defaultExecutor{s: s}
|
||||
}
|
||||
|
||||
@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mysqld/parser/antlrparser"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
@ -445,63 +447,6 @@ func Test_defaultExecutor_execSelect(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func Test_getOutputFieldsOrMatchCountRule(t *testing.T) {
|
||||
t.Run("match count rule", func(t *testing.T) {
|
||||
fl := []*planner.NodeSelectElement{
|
||||
planner.NewNodeSelectElement("", planner.WithFunctionCall(
|
||||
planner.NewNodeFunctionCall("", planner.WithAgg(
|
||||
planner.NewNodeAggregateWindowedFunction("", planner.WithAggCount(
|
||||
planner.NewNodeCount(""))))))),
|
||||
}
|
||||
_, match, err := getOutputFieldsOrMatchCountRule(fl)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, match)
|
||||
})
|
||||
|
||||
t.Run("star *, not supported", func(t *testing.T) {
|
||||
fl := []*planner.NodeSelectElement{
|
||||
planner.NewNodeSelectElement("", planner.WithStar()),
|
||||
}
|
||||
_, _, err := getOutputFieldsOrMatchCountRule(fl)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("combined", func(t *testing.T) {
|
||||
fl := []*planner.NodeSelectElement{
|
||||
planner.NewNodeSelectElement("", planner.WithFunctionCall(
|
||||
planner.NewNodeFunctionCall("", planner.WithAgg(
|
||||
planner.NewNodeAggregateWindowedFunction("", planner.WithAggCount(
|
||||
planner.NewNodeCount(""))))))),
|
||||
planner.NewNodeSelectElement("", planner.WithFullColumnName(
|
||||
planner.NewNodeFullColumnName("", "field"))),
|
||||
}
|
||||
_, _, err := getOutputFieldsOrMatchCountRule(fl)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("alias, not supported", func(t *testing.T) {
|
||||
fl := []*planner.NodeSelectElement{
|
||||
planner.NewNodeSelectElement("", planner.WithFullColumnName(
|
||||
planner.NewNodeFullColumnName("", "field", planner.FullColumnNameWithAlias("alias")))),
|
||||
}
|
||||
_, _, err := getOutputFieldsOrMatchCountRule(fl)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
fl := []*planner.NodeSelectElement{
|
||||
planner.NewNodeSelectElement("", planner.WithFullColumnName(
|
||||
planner.NewNodeFullColumnName("", "field1"))),
|
||||
planner.NewNodeSelectElement("", planner.WithFullColumnName(
|
||||
planner.NewNodeFullColumnName("", "field2"))),
|
||||
}
|
||||
outputFields, match, err := getOutputFieldsOrMatchCountRule(fl)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, match)
|
||||
assert.ElementsMatch(t, []string{"field1", "field2"}, outputFields)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_defaultExecutor_execCountWithFilter(t *testing.T) {
|
||||
t.Run("failed to query", func(t *testing.T) {
|
||||
s := mocks.NewProxyComponent(t)
|
||||
@ -593,129 +538,67 @@ func Test_defaultExecutor_execQuery(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func Test_wrapCountResult(t *testing.T) {
|
||||
sqlRes := wrapCountResult(100, "count(*)")
|
||||
assert.Equal(t, 1, len(sqlRes.Fields))
|
||||
assert.Equal(t, "count(*)", sqlRes.Fields[0].Name)
|
||||
assert.Equal(t, querypb.Type_INT64, sqlRes.Fields[0].Type)
|
||||
assert.Equal(t, 1, len(sqlRes.Rows))
|
||||
assert.Equal(t, 1, len(sqlRes.Rows[0]))
|
||||
assert.Equal(t, querypb.Type_INT64, sqlRes.Rows[0][0].Type())
|
||||
}
|
||||
|
||||
func Test_wrapQueryResults(t *testing.T) {
|
||||
res := &milvuspb.QueryResults{
|
||||
Status: &commonpb.Status{},
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: "field",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3, 4},
|
||||
},
|
||||
},
|
||||
func Test_defaultExecutor_execANNS(t *testing.T) {
|
||||
f1 := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: "pk",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: []int64{6, 5, 4, 3, 2, 1},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
f2 := &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Float,
|
||||
FieldName: "random",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_FloatData{
|
||||
FloatData: &schemapb.FloatArray{
|
||||
Data: []float32{6.6, 5.5, 4.4, 3.3, 2.2, 1.1},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
CollectionName: "test",
|
||||
}
|
||||
sqlRes := wrapQueryResults(res)
|
||||
assert.Equal(t, 1, len(sqlRes.Fields))
|
||||
assert.Equal(t, 4, len(sqlRes.Rows))
|
||||
assert.Equal(t, "field", sqlRes.Fields[0].Name)
|
||||
assert.Equal(t, querypb.Type_INT64, sqlRes.Fields[0].Type)
|
||||
assert.Equal(t, 1, len(sqlRes.Rows[0]))
|
||||
assert.Equal(t, querypb.Type_INT64, sqlRes.Rows[0][0].Type())
|
||||
}
|
||||
|
||||
func Test_getSQLField(t *testing.T) {
|
||||
f := &schemapb.FieldData{
|
||||
FieldName: "a",
|
||||
Type: schemapb.DataType_Int64,
|
||||
res := &milvuspb.SearchResults{
|
||||
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
|
||||
Results: &schemapb.SearchResultData{
|
||||
NumQueries: 2,
|
||||
TopK: 3,
|
||||
FieldsData: []*schemapb.FieldData{f1, f2},
|
||||
Scores: []float32{1.1, 2.2, 3.3, 4.4, 5.5, 6.6},
|
||||
Topks: []int64{2, 2},
|
||||
},
|
||||
CollectionName: "hello_milvus",
|
||||
}
|
||||
sf := getSQLField("t", f)
|
||||
assert.Equal(t, "a", sf.Name)
|
||||
assert.Equal(t, querypb.Type_INT64, sf.Type)
|
||||
assert.Equal(t, "t", sf.Table)
|
||||
}
|
||||
s := mocks.NewProxyComponent(t)
|
||||
s.On("Search",
|
||||
mock.Anything, // context.Context
|
||||
mock.Anything, // *milvuspb.SearchRequest
|
||||
).Return(res, nil)
|
||||
|
||||
func Test_toSQLType(t *testing.T) {
|
||||
type args struct {
|
||||
t schemapb.DataType
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want querypb.Type
|
||||
}{
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_Bool,
|
||||
},
|
||||
want: querypb.Type_UINT8,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_Int8,
|
||||
},
|
||||
want: querypb.Type_INT8,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_Int16,
|
||||
},
|
||||
want: querypb.Type_INT16,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_Int32,
|
||||
},
|
||||
want: querypb.Type_INT32,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_Int64,
|
||||
},
|
||||
want: querypb.Type_INT64,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_Float,
|
||||
},
|
||||
want: querypb.Type_FLOAT32,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_Double,
|
||||
},
|
||||
want: querypb.Type_FLOAT64,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_VarChar,
|
||||
},
|
||||
want: querypb.Type_VARCHAR,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_FloatVector,
|
||||
},
|
||||
want: querypb.Type_NULL_TYPE,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_BinaryVector,
|
||||
},
|
||||
want: querypb.Type_NULL_TYPE,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equalf(t, tt.want, toSQLType(tt.args.t), "toSQLType(%v)", tt.args.t)
|
||||
})
|
||||
}
|
||||
sql := `
|
||||
select
|
||||
$query_number, pk, random, $distance
|
||||
from hello_milvus
|
||||
where random > 0.5
|
||||
anns by embeddings -> ([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], [0.8, 0.7, 0.6, 0.6, 0.4, 0.3, 0.2, 0.1])
|
||||
params = (metric_type=L2, nprobe=10)
|
||||
limit 3
|
||||
`
|
||||
plan, warns, err := antlrparser.NewAntlrParser().Parse(sql)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, warns)
|
||||
|
||||
e := NewDefaultExecutor(s).(*defaultExecutor)
|
||||
_, err = e.execANNS(context.TODO(),
|
||||
antlrparser.GetSqlStatements(plan.Node).Statements[0].DmlStatement.Unwrap().SelectStatement.Unwrap().SimpleSelect.Unwrap().Query.Unwrap(),
|
||||
[]string{"$query_number", "pk", "random", "$distance"})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
363
internal/mysqld/executor/utils.go
Normal file
363
internal/mysqld/executor/utils.go
Normal file
@ -0,0 +1,363 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
querypb "github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
|
||||
"github.com/xelabs/go-mysqlstack/sqlparser/depends/sqltypes"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/mysqld/planner"
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
)
|
||||
|
||||
func prepareSearchParams(limitClause *planner.NodeLimitClause, annsClause *planner.NodeANNSClause) map[string]string {
|
||||
searchParams := map[string]string{
|
||||
common.AnnsFieldKey: annsClause.Column.Name,
|
||||
common.TopKKey: fmt.Sprintf("%d", limitClause.Limit),
|
||||
common.OffsetKey: fmt.Sprintf("%d", limitClause.Offset),
|
||||
}
|
||||
|
||||
if annsClause.Params.IsSome() {
|
||||
for k, v := range annsClause.Params.Unwrap().KVs {
|
||||
searchParams[k] = v
|
||||
}
|
||||
params, _ := json.Marshal(annsClause.Params.Unwrap().KVs)
|
||||
searchParams[common.SearchParamsKey] = string(params)
|
||||
}
|
||||
|
||||
return searchParams
|
||||
}
|
||||
|
||||
func generateOutputsIndex(outputs []string) (index map[string]int, users []string) {
|
||||
index = make(map[string]int, len(outputs))
|
||||
users = make([]string, 0, len(outputs))
|
||||
|
||||
fix := func(s string) string {
|
||||
return strings.ToLower(strings.TrimSpace(s))
|
||||
}
|
||||
|
||||
for i, f := range outputs {
|
||||
fixed := fix(f)
|
||||
if fixed != common.QueryNumberKey && fixed != common.DistanceKey {
|
||||
users = append(users, f)
|
||||
index[f] = i
|
||||
} else {
|
||||
index[fixed] = i
|
||||
}
|
||||
}
|
||||
|
||||
return index, users
|
||||
}
|
||||
|
||||
func restoreExpr(from *planner.NodeFromClause) string {
|
||||
filter := ""
|
||||
if from.Where.IsSome() {
|
||||
filter = planner.NewExprTextRestorer().RestoreExprText(from.Where.Unwrap())
|
||||
}
|
||||
return filter
|
||||
}
|
||||
|
||||
func vector2PlaceholderGroupBytes(vectors []*planner.NodeVector) []byte {
|
||||
phg := &commonpb.PlaceholderGroup{
|
||||
Placeholders: []*commonpb.PlaceholderValue{
|
||||
vector2Placeholder(vectors),
|
||||
},
|
||||
}
|
||||
|
||||
bs, _ := proto.Marshal(phg)
|
||||
return bs
|
||||
}
|
||||
|
||||
func vector2Placeholder(vectors []*planner.NodeVector) *commonpb.PlaceholderValue {
|
||||
var placeHolderType commonpb.PlaceholderType
|
||||
|
||||
ph := &commonpb.PlaceholderValue{
|
||||
Tag: "$0",
|
||||
Values: make([][]byte, 0, len(vectors)),
|
||||
}
|
||||
|
||||
if len(vectors) == 0 {
|
||||
return ph
|
||||
}
|
||||
|
||||
if vectors[0].FloatVector.IsSome() {
|
||||
placeHolderType = commonpb.PlaceholderType_FloatVector
|
||||
} else {
|
||||
placeHolderType = commonpb.PlaceholderType_BinaryVector
|
||||
|
||||
}
|
||||
|
||||
ph.Type = placeHolderType
|
||||
for _, vector := range vectors {
|
||||
ph.Values = append(ph.Values, vector.Serialize())
|
||||
}
|
||||
|
||||
return ph
|
||||
}
|
||||
|
||||
func prepareSearchReq(tableName string, filter string, vectors []*planner.NodeVector, userOutputs []string, searchParams map[string]string) *milvuspb.SearchRequest {
|
||||
phg := vector2PlaceholderGroupBytes(vectors)
|
||||
req := &milvuspb.SearchRequest{
|
||||
CollectionName: tableName,
|
||||
PartitionNames: nil,
|
||||
Dsl: filter,
|
||||
PlaceholderGroup: phg,
|
||||
DslType: commonpb.DslType_BoolExprV1,
|
||||
OutputFields: userOutputs,
|
||||
SearchParams: funcutil.Map2KeyValuePair(searchParams),
|
||||
TravelTimestamp: 0,
|
||||
GuaranteeTimestamp: 2, // default bounded consistency level.
|
||||
Nq: int64(len(vectors)),
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func wrapFieldsData(collectionName string, fieldsData []*schemapb.FieldData) *sqltypes.Result {
|
||||
nColumn := len(fieldsData)
|
||||
fields := make([]*querypb.Field, 0, nColumn)
|
||||
|
||||
if nColumn <= 0 {
|
||||
return &sqltypes.Result{}
|
||||
}
|
||||
|
||||
for i := 0; i < nColumn; i++ {
|
||||
fields = append(fields, getSQLField(collectionName, fieldsData[i]))
|
||||
}
|
||||
|
||||
nRow := typeutil.GetRowCount(fieldsData[0])
|
||||
rows := make([][]sqltypes.Value, 0, nRow)
|
||||
for i := 0; i < nRow; i++ {
|
||||
row := make([]sqltypes.Value, 0, nColumn)
|
||||
for j := 0; j < nColumn; j++ {
|
||||
row = append(row, getDataSingle(fieldsData[j], i))
|
||||
}
|
||||
rows = append(rows, row)
|
||||
}
|
||||
|
||||
return &sqltypes.Result{
|
||||
Fields: fields,
|
||||
Rows: rows,
|
||||
}
|
||||
}
|
||||
|
||||
func wrapQueryResults(res *milvuspb.QueryResults) *sqltypes.Result {
|
||||
return wrapFieldsData(res.GetCollectionName(), res.GetFieldsData())
|
||||
}
|
||||
|
||||
func wrapCountResult(rowCnt int, column string) *sqltypes.Result {
|
||||
result1 := &sqltypes.Result{
|
||||
Fields: []*querypb.Field{
|
||||
{
|
||||
Name: column,
|
||||
Type: querypb.Type_INT64,
|
||||
},
|
||||
},
|
||||
Rows: [][]sqltypes.Value{
|
||||
{
|
||||
sqltypes.NewInt64(int64(rowCnt)),
|
||||
},
|
||||
},
|
||||
}
|
||||
return result1
|
||||
}
|
||||
|
||||
func getSQLField(tableName string, fieldData *schemapb.FieldData) *querypb.Field {
|
||||
return &querypb.Field{
|
||||
Name: fieldData.GetFieldName(),
|
||||
Type: toSQLType(fieldData.GetType()),
|
||||
Table: tableName,
|
||||
OrgTable: "",
|
||||
Database: "",
|
||||
OrgName: "",
|
||||
ColumnLength: 0,
|
||||
Charset: 0,
|
||||
Decimals: 0,
|
||||
Flags: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func toSQLType(t schemapb.DataType) querypb.Type {
|
||||
switch t {
|
||||
case schemapb.DataType_Bool:
|
||||
// TODO: tinyint
|
||||
return querypb.Type_UINT8
|
||||
case schemapb.DataType_Int8:
|
||||
return querypb.Type_INT8
|
||||
case schemapb.DataType_Int16:
|
||||
return querypb.Type_INT16
|
||||
case schemapb.DataType_Int32:
|
||||
return querypb.Type_INT32
|
||||
case schemapb.DataType_Int64:
|
||||
return querypb.Type_INT64
|
||||
case schemapb.DataType_Float:
|
||||
return querypb.Type_FLOAT32
|
||||
case schemapb.DataType_Double:
|
||||
return querypb.Type_FLOAT64
|
||||
case schemapb.DataType_VarChar:
|
||||
return querypb.Type_VARCHAR
|
||||
// TODO: vector.
|
||||
default:
|
||||
return querypb.Type_NULL_TYPE
|
||||
}
|
||||
}
|
||||
|
||||
func getDataSingle(fieldData *schemapb.FieldData, idx int) sqltypes.Value {
|
||||
switch fieldData.GetType() {
|
||||
case schemapb.DataType_Bool:
|
||||
// TODO: tinyint
|
||||
return sqltypes.NewInt32(1)
|
||||
case schemapb.DataType_Int8:
|
||||
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_IntData).IntData.GetData()[idx]
|
||||
return sqltypes.MakeTrusted(sqltypes.Int8, strconv.AppendInt(nil, int64(v), 10))
|
||||
case schemapb.DataType_Int16:
|
||||
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_IntData).IntData.GetData()[idx]
|
||||
return sqltypes.MakeTrusted(sqltypes.Int16, strconv.AppendInt(nil, int64(v), 10))
|
||||
case schemapb.DataType_Int32:
|
||||
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_IntData).IntData.GetData()[idx]
|
||||
return sqltypes.MakeTrusted(sqltypes.Int32, strconv.AppendInt(nil, int64(v), 10))
|
||||
case schemapb.DataType_Int64:
|
||||
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_LongData).LongData.GetData()[idx]
|
||||
return sqltypes.MakeTrusted(sqltypes.Int64, strconv.AppendInt(nil, v, 10))
|
||||
case schemapb.DataType_Float:
|
||||
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_FloatData).FloatData.GetData()[idx]
|
||||
return sqltypes.MakeTrusted(sqltypes.Float32, strconv.AppendFloat(nil, float64(v), 'f', -1, 64))
|
||||
case schemapb.DataType_Double:
|
||||
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_DoubleData).DoubleData.GetData()[idx]
|
||||
return sqltypes.MakeTrusted(sqltypes.Float64, strconv.AppendFloat(nil, v, 'g', -1, 64))
|
||||
case schemapb.DataType_VarChar:
|
||||
v := fieldData.Field.(*schemapb.FieldData_Scalars).Scalars.Data.(*schemapb.ScalarField_StringData).StringData.GetData()[idx]
|
||||
return sqltypes.NewVarChar(v)
|
||||
|
||||
// TODO: vector.
|
||||
default:
|
||||
// TODO: should raise error here.
|
||||
return sqltypes.NewInt32(1)
|
||||
}
|
||||
}
|
||||
|
||||
func wrapSearchResult(res *milvuspb.SearchResults, outputsIndex map[string]int, userOutputs []string) *sqltypes.Result {
|
||||
qnIndex, qnOk := outputsIndex[common.QueryNumberKey]
|
||||
disIndex, disOk := outputsIndex[common.DistanceKey]
|
||||
|
||||
fieldsData := make([]*schemapb.FieldData, len(res.GetResults().GetFieldsData()))
|
||||
copy(fieldsData, res.GetResults().GetFieldsData())
|
||||
|
||||
insert := func(fields []*schemapb.FieldData, field *schemapb.FieldData, index int) []*schemapb.FieldData {
|
||||
if index >= len(fields) {
|
||||
return append(fields, field)
|
||||
}
|
||||
return append(fieldsData[:index], append([]*schemapb.FieldData{field}, fieldsData[index:]...)...)
|
||||
}
|
||||
|
||||
if qnOk && disOk {
|
||||
qnField := generateQueryNumberFieldData(res)
|
||||
disField := generateDistanceFieldsData(res)
|
||||
|
||||
index1, index2 := qnIndex, disIndex
|
||||
f1, f2 := qnField, disField
|
||||
if qnIndex > disIndex {
|
||||
index1, index2 = disIndex, qnIndex
|
||||
f1, f2 = disField, qnField
|
||||
}
|
||||
|
||||
fieldsData = insert(fieldsData, f1, index1)
|
||||
fieldsData = insert(fieldsData, f2, index2)
|
||||
} else if qnOk {
|
||||
qnField := generateQueryNumberFieldData(res)
|
||||
fieldsData = insert(fieldsData, qnField, qnIndex)
|
||||
} else if disOk {
|
||||
disField := generateDistanceFieldsData(res)
|
||||
fieldsData = insert(fieldsData, disField, disIndex)
|
||||
}
|
||||
|
||||
return wrapFieldsData(res.GetCollectionName(), fieldsData)
|
||||
}
|
||||
|
||||
func generateQueryNumberFieldData(res *milvuspb.SearchResults) *schemapb.FieldData {
|
||||
arr := make([]int64, 0, res.GetResults().GetNumQueries()*res.GetResults().GetTopK())
|
||||
|
||||
for i, topk := range res.GetResults().GetTopks() {
|
||||
for j := 0; int64(j) < topk; j++ {
|
||||
arr = append(arr, int64(i))
|
||||
}
|
||||
}
|
||||
|
||||
return &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: common.QueryNumberKey,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: arr,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func generateDistanceFieldsData(res *milvuspb.SearchResults) *schemapb.FieldData {
|
||||
return &schemapb.FieldData{
|
||||
Type: schemapb.DataType_Float,
|
||||
FieldName: common.DistanceKey,
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_FloatData{
|
||||
FloatData: &schemapb.FloatArray{
|
||||
Data: res.GetResults().GetScores(),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getOutputFieldsOrMatchCountRule(fields []*planner.NodeSelectElement) (outputFields []string, match bool, err error) {
|
||||
match = false
|
||||
l := len(fields)
|
||||
|
||||
if l == 1 {
|
||||
entry := fields[0]
|
||||
match = entry.FunctionCall.IsSome() &&
|
||||
entry.FunctionCall.Unwrap().Agg.IsSome() &&
|
||||
entry.FunctionCall.Unwrap().Agg.Unwrap().AggCount.IsSome()
|
||||
}
|
||||
|
||||
if match {
|
||||
return nil, match, nil
|
||||
}
|
||||
|
||||
outputFields = make([]string, 0, l)
|
||||
for _, entry := range fields {
|
||||
if entry.Star.IsSome() {
|
||||
// TODO: support `select *`.
|
||||
return nil, match, fmt.Errorf("* is not supported")
|
||||
}
|
||||
|
||||
if entry.FunctionCall.IsSome() {
|
||||
return nil, match, fmt.Errorf("combined select elements is not supported")
|
||||
}
|
||||
|
||||
if entry.FullColumnName.IsSome() {
|
||||
c := entry.FullColumnName.Unwrap()
|
||||
if c.Alias.IsSome() {
|
||||
return nil, match, fmt.Errorf("alias for select elements is not supported")
|
||||
}
|
||||
outputFields = append(outputFields, c.Name)
|
||||
}
|
||||
}
|
||||
|
||||
return outputFields, false, nil
|
||||
}
|
||||
@ -1,6 +1,15 @@
|
||||
package executor
|
||||
|
||||
import "github.com/milvus-io/milvus/internal/mysqld/planner"
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/mysqld/planner"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/xelabs/go-mysqlstack/sqlparser/depends/query"
|
||||
)
|
||||
|
||||
func GenNodeExpression(field string, c int64, op planner.ComparisonOperator) *planner.NodeExpression {
|
||||
n := planner.NewNodeExpression("", planner.WithPredicate(
|
||||
@ -20,3 +29,187 @@ func GenNodeExpression(field string, c int64, op planner.ComparisonOperator) *pl
|
||||
)
|
||||
return n
|
||||
}
|
||||
|
||||
func Test_getOutputFieldsOrMatchCountRule(t *testing.T) {
|
||||
t.Run("match count rule", func(t *testing.T) {
|
||||
fl := []*planner.NodeSelectElement{
|
||||
planner.NewNodeSelectElement("", planner.WithFunctionCall(
|
||||
planner.NewNodeFunctionCall("", planner.WithAgg(
|
||||
planner.NewNodeAggregateWindowedFunction("", planner.WithAggCount(
|
||||
planner.NewNodeCount(""))))))),
|
||||
}
|
||||
_, match, err := getOutputFieldsOrMatchCountRule(fl)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, match)
|
||||
})
|
||||
|
||||
t.Run("star *, not supported", func(t *testing.T) {
|
||||
fl := []*planner.NodeSelectElement{
|
||||
planner.NewNodeSelectElement("", planner.WithStar()),
|
||||
}
|
||||
_, _, err := getOutputFieldsOrMatchCountRule(fl)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("combined", func(t *testing.T) {
|
||||
fl := []*planner.NodeSelectElement{
|
||||
planner.NewNodeSelectElement("", planner.WithFunctionCall(
|
||||
planner.NewNodeFunctionCall("", planner.WithAgg(
|
||||
planner.NewNodeAggregateWindowedFunction("", planner.WithAggCount(
|
||||
planner.NewNodeCount(""))))))),
|
||||
planner.NewNodeSelectElement("", planner.WithFullColumnName(
|
||||
planner.NewNodeFullColumnName("", "field"))),
|
||||
}
|
||||
_, _, err := getOutputFieldsOrMatchCountRule(fl)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("alias, not supported", func(t *testing.T) {
|
||||
fl := []*planner.NodeSelectElement{
|
||||
planner.NewNodeSelectElement("", planner.WithFullColumnName(
|
||||
planner.NewNodeFullColumnName("", "field", planner.FullColumnNameWithAlias("alias")))),
|
||||
}
|
||||
_, _, err := getOutputFieldsOrMatchCountRule(fl)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("normal case", func(t *testing.T) {
|
||||
fl := []*planner.NodeSelectElement{
|
||||
planner.NewNodeSelectElement("", planner.WithFullColumnName(
|
||||
planner.NewNodeFullColumnName("", "field1"))),
|
||||
planner.NewNodeSelectElement("", planner.WithFullColumnName(
|
||||
planner.NewNodeFullColumnName("", "field2"))),
|
||||
}
|
||||
outputFields, match, err := getOutputFieldsOrMatchCountRule(fl)
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, match)
|
||||
assert.ElementsMatch(t, []string{"field1", "field2"}, outputFields)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_wrapCountResult(t *testing.T) {
|
||||
sqlRes := wrapCountResult(100, "count(*)")
|
||||
assert.Equal(t, 1, len(sqlRes.Fields))
|
||||
assert.Equal(t, "count(*)", sqlRes.Fields[0].Name)
|
||||
assert.Equal(t, query.Type_INT64, sqlRes.Fields[0].Type)
|
||||
assert.Equal(t, 1, len(sqlRes.Rows))
|
||||
assert.Equal(t, 1, len(sqlRes.Rows[0]))
|
||||
assert.Equal(t, query.Type_INT64, sqlRes.Rows[0][0].Type())
|
||||
}
|
||||
|
||||
func Test_wrapQueryResults(t *testing.T) {
|
||||
res := &milvuspb.QueryResults{
|
||||
Status: &commonpb.Status{},
|
||||
FieldsData: []*schemapb.FieldData{
|
||||
{
|
||||
Type: schemapb.DataType_Int64,
|
||||
FieldName: "field",
|
||||
Field: &schemapb.FieldData_Scalars{
|
||||
Scalars: &schemapb.ScalarField{
|
||||
Data: &schemapb.ScalarField_LongData{
|
||||
LongData: &schemapb.LongArray{
|
||||
Data: []int64{1, 2, 3, 4},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
CollectionName: "test",
|
||||
}
|
||||
sqlRes := wrapQueryResults(res)
|
||||
assert.Equal(t, 1, len(sqlRes.Fields))
|
||||
assert.Equal(t, 4, len(sqlRes.Rows))
|
||||
assert.Equal(t, "field", sqlRes.Fields[0].Name)
|
||||
assert.Equal(t, query.Type_INT64, sqlRes.Fields[0].Type)
|
||||
assert.Equal(t, 1, len(sqlRes.Rows[0]))
|
||||
assert.Equal(t, query.Type_INT64, sqlRes.Rows[0][0].Type())
|
||||
}
|
||||
|
||||
func Test_getSQLField(t *testing.T) {
|
||||
f := &schemapb.FieldData{
|
||||
FieldName: "a",
|
||||
Type: schemapb.DataType_Int64,
|
||||
}
|
||||
sf := getSQLField("t", f)
|
||||
assert.Equal(t, "a", sf.Name)
|
||||
assert.Equal(t, query.Type_INT64, sf.Type)
|
||||
assert.Equal(t, "t", sf.Table)
|
||||
}
|
||||
|
||||
func Test_toSQLType(t *testing.T) {
|
||||
type args struct {
|
||||
t schemapb.DataType
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want query.Type
|
||||
}{
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_Bool,
|
||||
},
|
||||
want: query.Type_UINT8,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_Int8,
|
||||
},
|
||||
want: query.Type_INT8,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_Int16,
|
||||
},
|
||||
want: query.Type_INT16,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_Int32,
|
||||
},
|
||||
want: query.Type_INT32,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_Int64,
|
||||
},
|
||||
want: query.Type_INT64,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_Float,
|
||||
},
|
||||
want: query.Type_FLOAT32,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_Double,
|
||||
},
|
||||
want: query.Type_FLOAT64,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_VarChar,
|
||||
},
|
||||
want: query.Type_VARCHAR,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_FloatVector,
|
||||
},
|
||||
want: query.Type_NULL_TYPE,
|
||||
},
|
||||
{
|
||||
args: args{
|
||||
t: schemapb.DataType_BinaryVector,
|
||||
},
|
||||
want: query.Type_NULL_TYPE,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equalf(t, tt.want, toSQLType(tt.args.t), "toSQLType(%v)", tt.args.t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,9 +1,25 @@
|
||||
package planner
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
)
|
||||
|
||||
// NodeFloatVector is a planner node holding a float32 vector literal.
type NodeFloatVector struct {
	// Array holds the vector components in order.
	Array []float32
}
|
||||
|
||||
func (n NodeFloatVector) Serialize() []byte {
|
||||
data := make([]byte, 0, 4*len(n.Array)) // float32 occupies 4 bytes
|
||||
buf := make([]byte, 4)
|
||||
for _, f := range n.Array {
|
||||
common.Endian.PutUint32(buf, math.Float32bits(f))
|
||||
data = append(data, buf...)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func NewNodeFloatVector(arr []float32) *NodeFloatVector {
|
||||
return &NodeFloatVector{
|
||||
Array: arr,
|
||||
|
||||
@ -14,6 +14,13 @@ func (n *NodeVector) apply(opts ...NodeVectorOption) {
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NodeVector) Serialize() []byte {
|
||||
if n.FloatVector.IsSome() {
|
||||
return n.FloatVector.Unwrap().Serialize()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func WithFloatVector(v *NodeFloatVector) NodeVectorOption {
|
||||
return func(n *NodeVector) {
|
||||
n.FloatVector = optional.Some(v)
|
||||
|
||||
@ -42,16 +42,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
IgnoreGrowingKey = "ignore_growing"
|
||||
AnnsFieldKey = "anns_field"
|
||||
TopKKey = "topk"
|
||||
NQKey = "nq"
|
||||
MetricTypeKey = "metric_type"
|
||||
SearchParamsKey = "params"
|
||||
RoundDecimalKey = "round_decimal"
|
||||
OffsetKey = "offset"
|
||||
LimitKey = "limit"
|
||||
|
||||
InsertTaskName = "InsertTask"
|
||||
CreateCollectionTaskName = "CreateCollectionTask"
|
||||
DropCollectionTaskName = "DropCollectionTask"
|
||||
|
||||
@ -21,6 +21,8 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/common"
|
||||
|
||||
"github.com/cockroachdb/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
@ -287,7 +289,7 @@ func Test_parseIndexParams(t *testing.T) {
|
||||
Value: "HNSW",
|
||||
},
|
||||
{
|
||||
Key: MetricTypeKey,
|
||||
Key: common.MetricTypeKey,
|
||||
Value: "IP",
|
||||
},
|
||||
{
|
||||
@ -321,7 +323,7 @@ func Test_parseIndexParams(t *testing.T) {
|
||||
Value: "128",
|
||||
},
|
||||
{
|
||||
Key: MetricTypeKey,
|
||||
Key: common.MetricTypeKey,
|
||||
Value: "L2",
|
||||
},
|
||||
}},
|
||||
@ -338,7 +340,7 @@ func Test_parseIndexParams(t *testing.T) {
|
||||
Value: "HNSW",
|
||||
},
|
||||
{
|
||||
Key: MetricTypeKey,
|
||||
Key: common.MetricTypeKey,
|
||||
Value: "IP",
|
||||
},
|
||||
{
|
||||
|
||||
@ -134,32 +134,32 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
|
||||
err error
|
||||
)
|
||||
|
||||
limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, queryParamsPair)
|
||||
limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(common.LimitKey, queryParamsPair)
|
||||
// if limit is not provided
|
||||
if err != nil {
|
||||
return &queryParams{limit: typeutil.Unlimited}, nil
|
||||
}
|
||||
limit, err = strconv.ParseInt(limitStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s [%s] is invalid", LimitKey, limitStr)
|
||||
return nil, fmt.Errorf("%s [%s] is invalid", common.LimitKey, limitStr)
|
||||
}
|
||||
if limit != 0 {
|
||||
if err := validateLimit(limit); err != nil {
|
||||
return nil, fmt.Errorf("%s [%d] is invalid, %w", LimitKey, limit, err)
|
||||
return nil, fmt.Errorf("%s [%d] is invalid, %w", common.LimitKey, limit, err)
|
||||
}
|
||||
}
|
||||
|
||||
offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, queryParamsPair)
|
||||
offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(common.OffsetKey, queryParamsPair)
|
||||
// if offset is provided
|
||||
if err == nil {
|
||||
offset, err = strconv.ParseInt(offsetStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
|
||||
return nil, fmt.Errorf("%s [%s] is invalid", common.OffsetKey, offsetStr)
|
||||
}
|
||||
|
||||
if offset != 0 {
|
||||
if err := validateLimit(offset); err != nil {
|
||||
return nil, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)
|
||||
return nil, fmt.Errorf("%s [%d] is invalid, %w", common.OffsetKey, offset, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -289,7 +289,7 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
|
||||
//fetch search_growing from search param
|
||||
var ignoreGrowing bool
|
||||
for i, kv := range t.request.GetQueryParams() {
|
||||
if kv.GetKey() == IgnoreGrowingKey {
|
||||
if kv.GetKey() == common.IgnoreGrowingKey {
|
||||
ignoreGrowing, err = strconv.ParseBool(kv.Value)
|
||||
if err != nil {
|
||||
return errors.New("parse search growing failed")
|
||||
|
||||
@ -154,7 +154,7 @@ func TestQueryTask_all(t *testing.T) {
|
||||
Expr: expr,
|
||||
QueryParams: []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: IgnoreGrowingKey,
|
||||
Key: common.IgnoreGrowingKey,
|
||||
Value: "false",
|
||||
},
|
||||
},
|
||||
@ -392,16 +392,16 @@ func TestTaskQuery_functions(t *testing.T) {
|
||||
outOffset int64
|
||||
}{
|
||||
{"empty input", []string{}, []string{}, false, typeutil.Unlimited, 0},
|
||||
{"valid limit=1", []string{LimitKey}, []string{"1"}, false, 1, 0},
|
||||
{"valid limit=1, offset=2", []string{LimitKey, OffsetKey}, []string{"1", "2"}, false, 1, 2},
|
||||
{"valid no limit, offset=2", []string{OffsetKey}, []string{"2"}, false, typeutil.Unlimited, 0},
|
||||
{"invalid limit str", []string{LimitKey}, []string{"a"}, true, 0, 0},
|
||||
{"invalid limit zero", []string{LimitKey}, []string{"0"}, true, 0, 0},
|
||||
{"invalid limit negative", []string{LimitKey}, []string{"-1"}, true, 0, 0},
|
||||
{"invalid limit 16385", []string{LimitKey}, []string{"16385"}, true, 0, 0},
|
||||
{"invalid offset negative", []string{LimitKey, OffsetKey}, []string{"1", "-1"}, true, 0, 0},
|
||||
{"invalid offset 16385", []string{LimitKey, OffsetKey}, []string{"1", "16385"}, true, 0, 0},
|
||||
{"invalid limit=16384 offset=16384", []string{LimitKey, OffsetKey}, []string{"16384", "16384"}, true, 0, 0},
|
||||
{"valid limit=1", []string{common.LimitKey}, []string{"1"}, false, 1, 0},
|
||||
{"valid limit=1, offset=2", []string{common.LimitKey, common.OffsetKey}, []string{"1", "2"}, false, 1, 2},
|
||||
{"valid no limit, offset=2", []string{common.OffsetKey}, []string{"2"}, false, typeutil.Unlimited, 0},
|
||||
{"invalid limit str", []string{common.LimitKey}, []string{"a"}, true, 0, 0},
|
||||
{"invalid limit zero", []string{common.LimitKey}, []string{"0"}, true, 0, 0},
|
||||
{"invalid limit negative", []string{common.LimitKey}, []string{"-1"}, true, 0, 0},
|
||||
{"invalid limit 16385", []string{common.LimitKey}, []string{"16385"}, true, 0, 0},
|
||||
{"invalid offset negative", []string{common.LimitKey, common.OffsetKey}, []string{"1", "-1"}, true, 0, 0},
|
||||
{"invalid offset 16385", []string{common.LimitKey, common.OffsetKey}, []string{"1", "16385"}, true, 0, 0},
|
||||
{"invalid limit=16384 offset=16384", []string{common.LimitKey, common.OffsetKey}, []string{"16384", "16384"}, true, 0, 0},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
|
||||
@ -96,36 +96,36 @@ func getPartitionIDs(ctx context.Context, collectionName string, partitionNames
|
||||
|
||||
// parseSearchInfo returns QueryInfo and offset
|
||||
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryInfo, int64, error) {
|
||||
topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
|
||||
topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(common.TopKKey, searchParamsPair)
|
||||
if err != nil {
|
||||
return nil, 0, errors.New(TopKKey + " not found in search_params")
|
||||
return nil, 0, errors.New(common.TopKKey + " not found in search_params")
|
||||
}
|
||||
topK, err := strconv.ParseInt(topKStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid", common.TopKKey, topKStr)
|
||||
}
|
||||
if err := validateLimit(topK); err != nil {
|
||||
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)
|
||||
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", common.TopKKey, topK, err)
|
||||
}
|
||||
|
||||
var offset int64
|
||||
offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, searchParamsPair)
|
||||
offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(common.OffsetKey, searchParamsPair)
|
||||
if err == nil {
|
||||
offset, err = strconv.ParseInt(offsetStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid", common.OffsetKey, offsetStr)
|
||||
}
|
||||
|
||||
if offset != 0 {
|
||||
if err := validateLimit(offset); err != nil {
|
||||
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)
|
||||
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", common.OffsetKey, offset, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
queryTopK := topK + offset
|
||||
if err := validateLimit(queryTopK); err != nil {
|
||||
return nil, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)
|
||||
return nil, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", common.OffsetKey, common.TopKKey, queryTopK, err)
|
||||
}
|
||||
|
||||
metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.MetricTypeKey, searchParamsPair)
|
||||
@ -133,20 +133,20 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryIn
|
||||
return nil, 0, errors.New(common.MetricTypeKey + " not found in search_params")
|
||||
}
|
||||
|
||||
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, searchParamsPair)
|
||||
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(common.RoundDecimalKey, searchParamsPair)
|
||||
if err != nil {
|
||||
roundDecimalStr = "-1"
|
||||
}
|
||||
|
||||
roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", common.RoundDecimalKey, roundDecimalStr)
|
||||
}
|
||||
|
||||
if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
|
||||
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", common.RoundDecimalKey, roundDecimalStr)
|
||||
}
|
||||
searchParamStr, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, searchParamsPair)
|
||||
searchParamStr, err := funcutil.GetAttrByKeyFromRepeatedKV(common.SearchParamsKey, searchParamsPair)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@ -244,7 +244,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
||||
//fetch search_growing from search param
|
||||
var ignoreGrowing bool
|
||||
for i, kv := range t.request.GetSearchParams() {
|
||||
if kv.GetKey() == IgnoreGrowingKey {
|
||||
if kv.GetKey() == common.IgnoreGrowingKey {
|
||||
ignoreGrowing, err = strconv.ParseBool(kv.GetValue())
|
||||
if err != nil {
|
||||
return errors.New("parse search growing failed")
|
||||
@ -256,9 +256,9 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
||||
t.SearchRequest.IgnoreGrowing = ignoreGrowing
|
||||
|
||||
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
|
||||
annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
|
||||
annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(common.AnnsFieldKey, t.request.GetSearchParams())
|
||||
if err != nil {
|
||||
return errors.New(AnnsFieldKey + " not found in search_params")
|
||||
return errors.New(common.AnnsFieldKey + " not found in search_params")
|
||||
}
|
||||
|
||||
queryInfo, offset, err := parseSearchInfo(t.request.GetSearchParams())
|
||||
@ -327,7 +327,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
||||
// Check if nq is valid:
|
||||
// https://milvus.io/docs/limitations.md
|
||||
if err := validateLimit(nq); err != nil {
|
||||
return fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
|
||||
return fmt.Errorf("%s [%d] is invalid, %w", common.NQKey, nq, err)
|
||||
}
|
||||
t.SearchRequest.Nq = nq
|
||||
|
||||
|
||||
@ -88,11 +88,11 @@ func createColl(t *testing.T, name string, rc types.RootCoord) {
|
||||
func getValidSearchParams() []*commonpb.KeyValuePair {
|
||||
return []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: AnnsFieldKey,
|
||||
Key: common.AnnsFieldKey,
|
||||
Value: testFloatVecField,
|
||||
},
|
||||
{
|
||||
Key: TopKKey,
|
||||
Key: common.TopKKey,
|
||||
Value: "10",
|
||||
},
|
||||
{
|
||||
@ -100,15 +100,15 @@ func getValidSearchParams() []*commonpb.KeyValuePair {
|
||||
Value: distance.L2,
|
||||
},
|
||||
{
|
||||
Key: SearchParamsKey,
|
||||
Key: common.SearchParamsKey,
|
||||
Value: `{"nprobe": 10}`,
|
||||
},
|
||||
{
|
||||
Key: RoundDecimalKey,
|
||||
Key: common.RoundDecimalKey,
|
||||
Value: "-1",
|
||||
},
|
||||
{
|
||||
Key: IgnoreGrowingKey,
|
||||
Key: common.IgnoreGrowingKey,
|
||||
Value: "false",
|
||||
}}
|
||||
}
|
||||
@ -230,7 +230,7 @@ func TestSearchTask_PreExecute(t *testing.T) {
|
||||
createColl(t, collName, rc)
|
||||
|
||||
task := getSearchTask(t, collName)
|
||||
task.request.SearchParams = getInvalidSearchParams(IgnoreGrowingKey)
|
||||
task.request.SearchParams = getInvalidSearchParams(common.IgnoreGrowingKey)
|
||||
err = task.PreExecute(ctx)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
@ -1852,7 +1852,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
||||
|
||||
sp := getValidSearchParams()
|
||||
sp = append(sp, &commonpb.KeyValuePair{
|
||||
Key: OffsetKey,
|
||||
Key: common.OffsetKey,
|
||||
Value: strconv.FormatInt(targetOffset, 10),
|
||||
})
|
||||
|
||||
@ -1864,26 +1864,26 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
||||
|
||||
t.Run("parseSearchInfo error", func(t *testing.T) {
|
||||
spNoTopk := []*commonpb.KeyValuePair{{
|
||||
Key: AnnsFieldKey,
|
||||
Key: common.AnnsFieldKey,
|
||||
Value: testFloatVecField}}
|
||||
|
||||
spInvalidTopk := append(spNoTopk, &commonpb.KeyValuePair{
|
||||
Key: TopKKey,
|
||||
Key: common.TopKKey,
|
||||
Value: "invalid",
|
||||
})
|
||||
|
||||
spInvalidTopk65536 := append(spNoTopk, &commonpb.KeyValuePair{
|
||||
Key: TopKKey,
|
||||
Key: common.TopKKey,
|
||||
Value: "65536",
|
||||
})
|
||||
|
||||
spNoMetricType := append(spNoTopk, &commonpb.KeyValuePair{
|
||||
Key: TopKKey,
|
||||
Key: common.TopKKey,
|
||||
Value: "10",
|
||||
})
|
||||
|
||||
spInvalidTopkPlusOffset := append(spNoTopk, &commonpb.KeyValuePair{
|
||||
Key: OffsetKey,
|
||||
Key: common.OffsetKey,
|
||||
Value: "65535",
|
||||
})
|
||||
|
||||
@ -1894,32 +1894,32 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
||||
|
||||
// no roundDecimal is valid
|
||||
noRoundDecimal := append(spNoSearchParams, &commonpb.KeyValuePair{
|
||||
Key: SearchParamsKey,
|
||||
Key: common.SearchParamsKey,
|
||||
Value: `{"nprobe": 10}`,
|
||||
})
|
||||
|
||||
spInvalidRoundDecimal2 := append(noRoundDecimal, &commonpb.KeyValuePair{
|
||||
Key: RoundDecimalKey,
|
||||
Key: common.RoundDecimalKey,
|
||||
Value: "1000",
|
||||
})
|
||||
|
||||
spInvalidRoundDecimal := append(noRoundDecimal, &commonpb.KeyValuePair{
|
||||
Key: RoundDecimalKey,
|
||||
Key: common.RoundDecimalKey,
|
||||
Value: "invalid",
|
||||
})
|
||||
|
||||
spInvalidOffsetNoInt := append(noRoundDecimal, &commonpb.KeyValuePair{
|
||||
Key: OffsetKey,
|
||||
Key: common.OffsetKey,
|
||||
Value: "invalid",
|
||||
})
|
||||
|
||||
spInvalidOffsetNegative := append(noRoundDecimal, &commonpb.KeyValuePair{
|
||||
Key: OffsetKey,
|
||||
Key: common.OffsetKey,
|
||||
Value: "-1",
|
||||
})
|
||||
|
||||
spInvalidOffsetTooLarge := append(noRoundDecimal, &commonpb.KeyValuePair{
|
||||
Key: OffsetKey,
|
||||
Key: common.OffsetKey,
|
||||
Value: "16386",
|
||||
})
|
||||
|
||||
|
||||
@ -329,19 +329,19 @@ func constructSearchRequest(
|
||||
Value: distance.L2,
|
||||
},
|
||||
{
|
||||
Key: SearchParamsKey,
|
||||
Key: common.SearchParamsKey,
|
||||
Value: string(b),
|
||||
},
|
||||
{
|
||||
Key: AnnsFieldKey,
|
||||
Key: common.AnnsFieldKey,
|
||||
Value: floatVecField,
|
||||
},
|
||||
{
|
||||
Key: TopKKey,
|
||||
Key: common.TopKKey,
|
||||
Value: strconv.Itoa(topk),
|
||||
},
|
||||
{
|
||||
Key: RoundDecimalKey,
|
||||
Key: common.RoundDecimalKey,
|
||||
Value: strconv.Itoa(roundDecimal),
|
||||
},
|
||||
},
|
||||
|
||||
@ -81,10 +81,21 @@ const (
|
||||
SearchParamKey = "search_param"
|
||||
SegmentNumKey = "segment_num"
|
||||
|
||||
IgnoreGrowingKey = "ignore_growing"
|
||||
AnnsFieldKey = "anns_field"
|
||||
NQKey = "nq"
|
||||
SearchParamsKey = "params"
|
||||
RoundDecimalKey = "round_decimal"
|
||||
OffsetKey = "offset"
|
||||
LimitKey = "limit"
|
||||
|
||||
IndexParamsKey = "params"
|
||||
IndexTypeKey = "index_type"
|
||||
MetricTypeKey = "metric_type"
|
||||
DimKey = "dim"
|
||||
|
||||
QueryNumberKey = "$query_number"
|
||||
DistanceKey = "$distance"
|
||||
)
|
||||
|
||||
// Collection properties key
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user