diff --git a/internal/mysqld/parser/antlrparser/ast_builder.go b/internal/mysqld/parser/antlrparser/ast_builder.go index 53aabe2572..1ef2d393a1 100644 --- a/internal/mysqld/parser/antlrparser/ast_builder.go +++ b/internal/mysqld/parser/antlrparser/ast_builder.go @@ -4,6 +4,8 @@ import ( "fmt" "strconv" + "github.com/milvus-io/milvus/internal/mysqld/sqlutil" + "github.com/milvus-io/milvus/internal/mysqld/planner" "github.com/antlr/antlr4/runtime/Go/antlr/v4" @@ -172,6 +174,17 @@ func (v *AstBuilder) VisitQuerySpecification(ctx *parsergen.QuerySpecificationCo } } + annsCtx := ctx.AnnsClause() + if annsCtx != nil { + r := annsCtx.Accept(v) + if err := GetError(r); err != nil { + return err + } + if n := GetANNSClause(r); n != nil { + opts = append(opts, planner.WithANNS(n)) + } + } + limitCtx := ctx.LimitClause() if limitCtx != nil { r := limitCtx.Accept(v) @@ -290,6 +303,167 @@ func (v *AstBuilder) VisitFromClause(ctx *parsergen.FromClauseContext) interface return planner.NewNodeFromClause(text, tableSources, opts...) } +func (v *AstBuilder) VisitAnnsClause(ctx *parsergen.AnnsClauseContext) interface{} { + text := GetOriginalText(ctx) + + var column *planner.NodeFullColumnName + var vectors []*planner.NodeVector + + columnCtx := ctx.FullColumnName() + if columnCtx != nil { + r := columnCtx.Accept(v) + if err := GetError(r); err != nil { + return err + } + column = planner.NewNodeFullColumnName(GetOriginalText(columnCtx), r.(string)) + } + + vectorsCtx := ctx.AnnsVectors() + if vectorsCtx != nil { + r := vectorsCtx.Accept(v) + if err := GetError(r); err != nil { + return err + } + vectors = r.([]*planner.NodeVector) + } + + var opts []planner.NodeANNSClauseOption + + var paramsCtx parsergen.IAnnsParamsClauseContext + + // Don't use ctx.AnnsParamsClause() directly, especially when the nq is too large. + // In fact, ctx.AnnsParamsClause() will iterate all children from the beginning index, which + // is not very efficient. + children := ctx.GetChildren() + lenOfChildren := len(children) + for i := lenOfChildren - 1; i >= 0; i-- { + if childCtx, ok := children[i].(parsergen.IAnnsParamsClauseContext); ok { + paramsCtx = childCtx + break + } + } + + if paramsCtx != nil { + r := paramsCtx.Accept(v) + if err := GetError(r); err != nil { + return err + } + if n := GetKVPairs(r); n != nil { + opts = append(opts, planner.NodeANNSClauseWithParams(n)) + } + } + + return planner.NewNodeANNSClause(text, column, vectors, opts...) +} + +func (v *AstBuilder) VisitAnnsVectors(ctx *parsergen.AnnsVectorsContext) interface{} { + var vectors []*planner.NodeVector + + allVectorsCtx := ctx.AllAnnsVector() + + for _, vectorCtx := range allVectorsCtx { + r := vectorCtx.Accept(v) + if err := GetError(r); err != nil { + return err + } + if n := GetVector(r); n != nil { + vectors = append(vectors, n) + } + } + + return vectors +} + +func (v *AstBuilder) VisitAnnsVector(ctx *parsergen.AnnsVectorContext) interface{} { + if ctx.BIT_STRING() != nil { + return fmt.Errorf("binary vector is not supported") + } + + var floatArray []float32 + + floatArrayCtx := ctx.FloatArray() + if floatArrayCtx != nil { + r := floatArrayCtx.Accept(v) + if err := GetError(r); err != nil { + return err + } + floatArray = r.([]float32) + } + + return planner.NewNodeVector(planner.WithFloatVector(planner.NewNodeFloatVector(floatArray))) +} + +func (v *AstBuilder) VisitFloatArray(ctx *parsergen.FloatArrayContext) interface{} { + var floatArray []float32 + + allDecimalCtx := ctx.AllDecimalLiteral() + for _, childCtx := range allDecimalCtx { + r := childCtx.Accept(v) + switch rWithType := r.(type) { + case int64: + floatArray = append(floatArray, float32(rWithType)) + case float32: + floatArray = append(floatArray, rWithType) + case float64: + floatArray = append(floatArray, float32(rWithType)) + case error: + return rWithType + default: + // TODO + return fmt.Errorf("failed to parse float vector: %s", GetOriginalText(childCtx)) + } + } + + return floatArray +} + +func (v *AstBuilder) VisitAnnsParamsClause(ctx *parsergen.AnnsParamsClauseContext) interface{} { + return ctx.KvPairs().Accept(v) +} + +func (v *AstBuilder) VisitKvPairs(ctx *parsergen.KvPairsContext) interface{} { + allKvPairs := ctx.AllKvPair() + lenOfPairs := len(allKvPairs) + + if lenOfPairs == 0 { + return nil + } + + pairs := planner.NewNodeKVPairs() + + for _, child := range allKvPairs { + childCtx := child.(*parsergen.KvPairContext) + key := childCtx.ID().GetText() + value := childCtx.Value().Accept(v) + switch tv := value.(type) { + case string: + pairs.Insert(key, tv) + case int: + pairs.Insert(key, strconv.Itoa(tv)) + case int32: + pairs.Insert(key, strconv.Itoa(int(tv))) + case int64: + pairs.Insert(key, strconv.Itoa(int(tv))) + case float32: + pairs.Insert(key, sqlutil.Float32ToString(tv)) + case float64: + pairs.Insert(key, sqlutil.Float64ToString(tv)) + default: + return fmt.Errorf("invalid type: %s", GetOriginalText(childCtx)) + } + } + + return pairs +} + +func (v *AstBuilder) VisitValue(ctx *parsergen.ValueContext) interface{} { + if idCtx := ctx.ID(); idCtx != nil { + return idCtx.GetText() + } + + return ctx.Constant().Accept(v) +} + func (v *AstBuilder) VisitTableSources(ctx *parsergen.TableSourcesContext) interface{} { // Should not be visited. return nil diff --git a/internal/mysqld/parser/antlrparser/node_ret.go b/internal/mysqld/parser/antlrparser/node_ret.go index d79650457a..8d5b9c37f4 100644 --- a/internal/mysqld/parser/antlrparser/node_ret.go +++ b/internal/mysqld/parser/antlrparser/node_ret.go @@ -77,6 +77,14 @@ func GetFromClause(obj interface{}) *planner.NodeFromClause { return n } +func GetANNSClause(obj interface{}) *planner.NodeANNSClause { + n, ok := obj.(*planner.NodeANNSClause) + if !ok { + return nil + } + return n +} + func GetLimitClause(obj interface{}) *planner.NodeLimitClause { n, ok := obj.(*planner.NodeLimitClause) if !ok { @@ -140,3 +148,19 @@ func GetExpressions(obj interface{}) *planner.NodeExpressions { } return n } + +func GetKVPairs(obj interface{}) *planner.NodeKVPairs { + n, ok := obj.(*planner.NodeKVPairs) + if !ok { + return nil + } + return n +} + +func GetVector(obj interface{}) *planner.NodeVector { + n, ok := obj.(*planner.NodeVector) + if !ok { + return nil + } + return n +} diff --git a/internal/mysqld/parser/antlrparser/parser_test.go b/internal/mysqld/parser/antlrparser/parser_test.go index 67808eec96..1b02bbc1b5 100644 --- a/internal/mysqld/parser/antlrparser/parser_test.go +++ b/internal/mysqld/parser/antlrparser/parser_test.go @@ -4,9 +4,10 @@ import ( "fmt" "testing" - "github.com/stretchr/testify/assert" - + "github.com/milvus-io/milvus/internal/mysqld/parser" "github.com/milvus-io/milvus/internal/mysqld/planner" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) func Test_antlrParser_Parse(t *testing.T) { @@ -48,3 +49,35 @@ func Test_antlrParser_Parse(t *testing.T) { debug(t, sql) } } + +type ANNSSuite struct { + suite.Suite + + p parser.Parser +} + +func (suite *ANNSSuite) SetupTest() { + suite.p = NewAntlrParser() +} + +func (suite *ANNSSuite) TearDownTest() {} + +func TestANNSSuite(t *testing.T) { + suite.Run(t, new(ANNSSuite)) +} + +func (suite *ANNSSuite) TestFloatVector() { + sql := ` +select query_number, id, distance +from t +where id >= 1000 and id <= 10000 +anns by feature -> ([0.23, 0.21], [0.24, 0.26]) PARAMS = (nprobe=1, ef=5) +limit 100 +` + + plan, warns, err := suite.p.Parse(sql) + suite.NoError(err) + suite.Nil(warns) + + planner.NewTreeUtils().PrettyPrintHrn(GetSqlStatements(plan.Node)) +} diff --git a/internal/mysqld/planner/anns_clause.go b/internal/mysqld/planner/anns_clause.go new file mode 100644 index 0000000000..b6273e6a3a --- /dev/null +++ b/internal/mysqld/planner/anns_clause.go @@ -0,0 +1,55 @@ +package planner + +import ( + "fmt" + + "github.com/moznion/go-optional" +) + +type NodeANNSClause struct { + baseNode + Column *NodeFullColumnName + Vectors []*NodeVector + Params optional.Option[*NodeKVPairs] +} + +func (n *NodeANNSClause) String() string { + s := fmt.Sprintf("NodeANNSClause, Column: %s, Nq: %d", n.Column.String(), len(n.Vectors)) + if n.Params.IsSome() { + s += fmt.Sprintf(", Params: %s", n.Params.Unwrap().String()) + } + return s +} + +func (n *NodeANNSClause) GetChildren() []Node { + // return []Node{n.Column} + return nil +} + +func (n *NodeANNSClause) Accept(v Visitor) interface{} { + return v.VisitANNSClause(n) +} + +type NodeANNSClauseOption func(*NodeANNSClause) + +func NodeANNSClauseWithParams(p *NodeKVPairs) NodeANNSClauseOption { + return func(n *NodeANNSClause) { + n.Params = optional.Some(p) + } +} + +func (n *NodeANNSClause) apply(opts ...NodeANNSClauseOption) { + for _, opt := range opts { + opt(n) + } +} + +func NewNodeANNSClause(text string, column *NodeFullColumnName, vectors []*NodeVector, opts ...NodeANNSClauseOption) *NodeANNSClause { + n := &NodeANNSClause{ + baseNode: newBaseNode(text), + Column: column, + Vectors: vectors, + } + n.apply(opts...) + return n +} diff --git a/internal/mysqld/planner/binary_vector.go b/internal/mysqld/planner/binary_vector.go new file mode 100644 index 0000000000..5f57e9c0a2 --- /dev/null +++ b/internal/mysqld/planner/binary_vector.go @@ -0,0 +1,6 @@ +package planner + +type NodeBinaryVector struct { + baseNode + //TODO +} diff --git a/internal/mysqld/planner/float_vector.go b/internal/mysqld/planner/float_vector.go new file mode 100644 index 0000000000..a9c7a81e24 --- /dev/null +++ b/internal/mysqld/planner/float_vector.go @@ -0,0 +1,11 @@ +package planner + +type NodeFloatVector struct { + Array []float32 +} + +func NewNodeFloatVector(arr []float32) *NodeFloatVector { + return &NodeFloatVector{ + Array: arr, + } +} diff --git a/internal/mysqld/planner/kv_pair.go b/internal/mysqld/planner/kv_pair.go new file mode 100644 index 0000000000..abd52a151e --- /dev/null +++ b/internal/mysqld/planner/kv_pair.go @@ -0,0 +1,23 @@ +package planner + +import "encoding/json" + +type NodeKVPairs struct { + KVs map[string]string +} + +func (n *NodeKVPairs) Insert(key, value string) { + n.KVs[key] = value +} + +func (n *NodeKVPairs) String() string { + // How could `Marshal` return error here? + bs, _ := json.Marshal(n.KVs) + return string(bs) +} + +func NewNodeKVPairs() *NodeKVPairs { + return &NodeKVPairs{ + KVs: make(map[string]string), + } +} diff --git a/internal/mysqld/planner/query_specification.go b/internal/mysqld/planner/query_specification.go index 52050b2dee..60543bfc7a 100644 --- a/internal/mysqld/planner/query_specification.go +++ b/internal/mysqld/planner/query_specification.go @@ -7,6 +7,7 @@ type NodeQuerySpecification struct { SelectSpecs []*NodeSelectSpec SelectElements []*NodeSelectElement From optional.Option[*NodeFromClause] + Anns optional.Option[*NodeANNSClause] Limit optional.Option[*NodeLimitClause] } @@ -16,18 +17,27 @@ func (n *NodeQuerySpecification) String() string { func (n *NodeQuerySpecification) GetChildren() []Node { children := make([]Node, 0, len(n.SelectSpecs)+len(n.SelectElements)+2) + for _, child := range n.SelectSpecs { children = append(children, child) } + for _, child := range n.SelectElements { children = append(children, child) } + if n.From.IsSome() { children = append(children, n.From.Unwrap()) } + + if n.Anns.IsSome() { + children = append(children, n.Anns.Unwrap()) + } + if n.Limit.IsSome() { children = append(children, n.Limit.Unwrap()) } + return children } @@ -49,6 +59,12 @@ func WithFrom(from *NodeFromClause) NodeQuerySpecificationOption { } } +func WithANNS(anns *NodeANNSClause) NodeQuerySpecificationOption { + return func(n *NodeQuerySpecification) { + n.Anns = optional.Some(anns) + } +} + func WithLimit(Limit *NodeLimitClause) NodeQuerySpecificationOption { return func(n *NodeQuerySpecification) { n.Limit = optional.Some(Limit) diff --git a/internal/mysqld/planner/vector.go b/internal/mysqld/planner/vector.go new file mode 100644 index 0000000000..5fcf3c7cd7 --- /dev/null +++ b/internal/mysqld/planner/vector.go @@ -0,0 +1,27 @@ +package planner + +import "github.com/moznion/go-optional" + +type NodeVector struct { + FloatVector optional.Option[*NodeFloatVector] +} + +type NodeVectorOption func(*NodeVector) + +func (n *NodeVector) apply(opts ...NodeVectorOption) { + for _, opt := range opts { + opt(n) + } +} + +func WithFloatVector(v *NodeFloatVector) NodeVectorOption { + return func(n *NodeVector) { + n.FloatVector = optional.Some(v) + } +} + +func NewNodeVector(opts ...NodeVectorOption) *NodeVector { + n := &NodeVector{} + n.apply(opts...) + return n +} diff --git a/internal/mysqld/planner/visitor.go b/internal/mysqld/planner/visitor.go index 4c1864dfc1..cd3ccb7439 100644 --- a/internal/mysqld/planner/visitor.go +++ b/internal/mysqld/planner/visitor.go @@ -31,4 +31,14 @@ type Visitor interface { VisitUnaryExpressionAtom(n *NodeUnaryExpressionAtom) interface{} VisitNestedExpressionAtom(n *NodeNestedExpressionAtom) interface{} VisitConstant(n *NodeConstant) interface{} + + /* + // In fact, these structs are not enough to be a node. + // They themselves alone don't make any sense. Just regard them as parameters. + VisitFloatVector(n *NodeFloatVector) interface{} + VisitVector(n *NodeVector) interface{} + VisitKVPairs(n *NodeKVPairs) interface{} + */ + + VisitANNSClause(*NodeANNSClause) interface{} } diff --git a/internal/mysqld/planner/visitor_expression_text_restorer.go b/internal/mysqld/planner/visitor_expression_text_restorer.go index 7eeb42f845..1e4eee3582 100644 --- a/internal/mysqld/planner/visitor_expression_text_restorer.go +++ b/internal/mysqld/planner/visitor_expression_text_restorer.go @@ -3,6 +3,8 @@ package planner import ( "fmt" "strconv" + + "github.com/milvus-io/milvus/internal/mysqld/sqlutil" ) // TODO: remove this after execution engine is ready. @@ -120,7 +122,7 @@ func (v *exprTextRestorer) VisitConstant(n *NodeConstant) interface{} { return strconv.FormatBool(n.BooleanLiteral.Unwrap()) } if n.RealLiteral.IsSome() { - return strconv.FormatFloat(n.RealLiteral.Unwrap(), 'g', 10, 14) + return sqlutil.Float64ToString(n.RealLiteral.Unwrap()) } return "" } diff --git a/internal/mysqld/planner/visitor_json.go b/internal/mysqld/planner/visitor_json.go index f45f1813f1..1eb7ae3427 100644 --- a/internal/mysqld/planner/visitor_json.go +++ b/internal/mysqld/planner/visitor_json.go @@ -1,6 +1,10 @@ package planner -import "strconv" +import ( + "strconv" + + "github.com/milvus-io/milvus/internal/mysqld/sqlutil" +) type jsonVisitor struct { } @@ -362,12 +366,22 @@ func (v jsonVisitor) VisitConstant(n *NodeConstant) interface{} { } if n.RealLiteral.IsSome() { - j["real_literal"] = strconv.FormatFloat(n.RealLiteral.Unwrap(), 'f', -1, 64) + j["real_literal"] = sqlutil.Float64ToString(n.RealLiteral.Unwrap()) } return j } +func (v jsonVisitor) VisitANNSClause(n *NodeANNSClause) interface{} { + // leaf node. + + j := map[string]interface{}{} + + j["anns"] = n.String() + + return j +} + func NewJSONVisitor() Visitor { return &jsonVisitor{} } diff --git a/internal/mysqld/sqlutil/conv.go b/internal/mysqld/sqlutil/conv.go new file mode 100644 index 0000000000..9a3b7b0674 --- /dev/null +++ b/internal/mysqld/sqlutil/conv.go @@ -0,0 +1,11 @@ +package sqlutil + +import "strconv" + +func Float32ToString(f float32) string { + return strconv.FormatFloat(float64(f), 'f', -1, 32) +} + +func Float64ToString(f float64) string { + return strconv.FormatFloat(f, 'f', -1, 64) +}