From 6084930854d224494b41ee04e69ff98362a770ce Mon Sep 17 00:00:00 2001 From: congqixia Date: Fri, 25 Apr 2025 10:38:38 +0800 Subject: [PATCH] fix: [GoSDK] Loose rowbased insert data check (#41498) Related to #41460 This PR loosens the insert data check based on schema. These checks should actually happen on the Milvus server side. --------- Signed-off-by: Congqi Xia --- client/row/data.go | 244 +++++++++++++---------- client/row/data_test.go | 4 +- tests/go_client/testcases/insert_test.go | 4 +- 3 files changed, 141 insertions(+), 111 deletions(-) diff --git a/client/row/data.go b/client/row/data.go index 8ab74112f2..27acdebecd 100644 --- a/client/row/data.go +++ b/client/row/data.go @@ -23,6 +23,7 @@ import ( "strconv" "github.com/cockroachdb/errors" + "github.com/samber/lo" "github.com/milvus-io/milvus/client/v2/column" "github.com/milvus-io/milvus/client/v2/entity" @@ -60,6 +61,9 @@ const ( DimMax = 65535 ) +// AnyToColumns converts input rows into column-based data. +// when schemas are provided, this method will use 0-th element +// otherwise, it shall try to parse schema from row[0] func AnyToColumns(rows []interface{}, schemas ...*entity.Schema) ([]column.Column, error) { rowsLen := len(rows) if rowsLen == 0 { @@ -70,6 +74,7 @@ func AnyToColumns(rows []interface{}, schemas ...*entity.Schema) ([]column.Colum var err error // if schema not provided, try to parse from row if len(schemas) == 0 { + //nolint rows number checked before sch, err = ParseSchema(rows[0]) if err != nil { return []column.Column{}, err @@ -83,101 +88,31 @@ func AnyToColumns(rows []interface{}, schemas ...*entity.Schema) ([]column.Colum var dynamicCol *column.ColumnJSONBytes nameColumns := make(map[string]column.Column) - for _, field := range sch.Fields { - // skip auto id pk field - if field.PrimaryKey && field.AutoID { - continue - } - - var col column.Column - switch field.DataType { - case entity.FieldTypeBool: - data := make([]bool, 0, rowsLen) - col = column.NewColumnBool(field.Name, data) - 
case entity.FieldTypeInt8: - data := make([]int8, 0, rowsLen) - col = column.NewColumnInt8(field.Name, data) - case entity.FieldTypeInt16: - data := make([]int16, 0, rowsLen) - col = column.NewColumnInt16(field.Name, data) - case entity.FieldTypeInt32: - data := make([]int32, 0, rowsLen) - col = column.NewColumnInt32(field.Name, data) - case entity.FieldTypeInt64: - data := make([]int64, 0, rowsLen) - col = column.NewColumnInt64(field.Name, data) - case entity.FieldTypeFloat: - data := make([]float32, 0, rowsLen) - col = column.NewColumnFloat(field.Name, data) - case entity.FieldTypeDouble: - data := make([]float64, 0, rowsLen) - col = column.NewColumnDouble(field.Name, data) - case entity.FieldTypeString, entity.FieldTypeVarChar: - data := make([]string, 0, rowsLen) - col = column.NewColumnVarChar(field.Name, data) - case entity.FieldTypeJSON: - data := make([][]byte, 0, rowsLen) - col = column.NewColumnJSONBytes(field.Name, data) - case entity.FieldTypeArray: - col = NewArrayColumn(field) - if col == nil { - return nil, errors.Newf("unsupported element type %s for Array", field.ElementType.String()) - } - case entity.FieldTypeFloatVector: - data := make([][]float32, 0, rowsLen) - dimStr, has := field.TypeParams[entity.TypeParamDim] - if !has { - return []column.Column{}, errors.New("vector field with no dim") - } - dim, err := strconv.ParseInt(dimStr, 10, 64) - if err != nil { - return []column.Column{}, fmt.Errorf("vector field with bad format dim: %s", err.Error()) - } - col = column.NewColumnFloatVector(field.Name, int(dim), data) - case entity.FieldTypeBinaryVector: - data := make([][]byte, 0, rowsLen) - dim, err := field.GetDim() - if err != nil { - return []column.Column{}, err - } - col = column.NewColumnBinaryVector(field.Name, int(dim), data) - case entity.FieldTypeFloat16Vector: - data := make([][]byte, 0, rowsLen) - dim, err := field.GetDim() - if err != nil { - return []column.Column{}, err - } - col = column.NewColumnFloat16Vector(field.Name, 
int(dim), data) - case entity.FieldTypeBFloat16Vector: - data := make([][]byte, 0, rowsLen) - dim, err := field.GetDim() - if err != nil { - return []column.Column{}, err - } - col = column.NewColumnBFloat16Vector(field.Name, int(dim), data) - case entity.FieldTypeSparseVector: - data := make([]entity.SparseEmbedding, 0, rowsLen) - col = column.NewColumnSparseVectors(field.Name, data) - case entity.FieldTypeInt8Vector: - data := make([][]int8, 0, rowsLen) - dim, err := field.GetDim() - if err != nil { - return []column.Column{}, err - } - col = column.NewColumnInt8Vector(field.Name, int(dim), data) - } - - if field.Nullable { - col.SetNullable(true) - } - - nameColumns[field.Name] = col - } + nameSchemas := lo.SliceToMap(sch.Fields, func(fieldSchema *entity.Field) (string, entity.Field) { + return fieldSchema.Name, *fieldSchema + }) + columnCreators := getColumnCreators(sch) if isDynamic { dynamicCol = column.NewColumnJSONBytes("", make([][]byte, 0, rowsLen)).WithIsDynamic(true) } + // getColumn is a closure to wrap fetch column related to field name + getColumn := func(fieldName string) (column.Column, error) { + // existing one + column, ok := nameColumns[fieldName] + if ok { + return column, nil + } + + fn, ok := columnCreators[fieldName] + if ok { + return fn(rowsLen) + } + + return nil, errors.New("column not found") + } + for _, row := range rows { // collection schema name need not to be same, since receiver could has other names v := reflect.ValueOf(row) @@ -186,31 +121,28 @@ func AnyToColumns(rows []interface{}, schemas ...*entity.Schema) ([]column.Colum return nil, err } - for idx, field := range sch.Fields { - // skip dynamic field if visible - if isDynamic && field.IsDynamic { - continue - } - // skip auto id pk field - if field.PrimaryKey && field.AutoID { + for fieldName, candi := range set { + fieldSch, ok := nameSchemas[fieldName] + if ok && fieldSch.PrimaryKey && fieldSch.AutoID { // remove pk field from candidates set, avoid adding it into dynamic 
column - delete(set, field.Name) + delete(set, fieldName) continue } - column, ok := nameColumns[field.Name] - if !ok { - return nil, fmt.Errorf("expected unhandled field %s", field.Name) - } - candi, ok := set[field.Name] - if !ok { - return nil, fmt.Errorf("row %d does not has field %s", idx, field.Name) + column, err := getColumn(fieldName) + if err != nil { + // ignore candidate not exist in schema for now + // if dynamic schema enabled, left candidates will be processed + // TODO @congqixia, add strict mode if needed + continue } - err := column.AppendValue(candi.v.Interface()) + nameColumns[fieldName] = column + + err = column.AppendValue(candi.v.Interface()) if err != nil { return nil, err } - delete(set, field.Name) + delete(set, fieldName) } if isDynamic { @@ -238,6 +170,104 @@ func AnyToColumns(rows []interface{}, schemas ...*entity.Schema) ([]column.Colum return columns, nil } +type columnCreator func(int) (column.Column, error) + +func getColumnCreators(sch *entity.Schema) map[string]columnCreator { + result := make(map[string]columnCreator) + for _, field := range sch.Fields { + // skip auto id pk field + // if field.PrimaryKey && field.AutoID { + // continue + // } + field := field + result[field.Name] = func(rowsLen int) (column.Column, error) { + var col column.Column + switch field.DataType { + case entity.FieldTypeBool: + data := make([]bool, 0, rowsLen) + col = column.NewColumnBool(field.Name, data) + case entity.FieldTypeInt8: + data := make([]int8, 0, rowsLen) + col = column.NewColumnInt8(field.Name, data) + case entity.FieldTypeInt16: + data := make([]int16, 0, rowsLen) + col = column.NewColumnInt16(field.Name, data) + case entity.FieldTypeInt32: + data := make([]int32, 0, rowsLen) + col = column.NewColumnInt32(field.Name, data) + case entity.FieldTypeInt64: + data := make([]int64, 0, rowsLen) + col = column.NewColumnInt64(field.Name, data) + case entity.FieldTypeFloat: + data := make([]float32, 0, rowsLen) + col = 
column.NewColumnFloat(field.Name, data) + case entity.FieldTypeDouble: + data := make([]float64, 0, rowsLen) + col = column.NewColumnDouble(field.Name, data) + case entity.FieldTypeString, entity.FieldTypeVarChar: + data := make([]string, 0, rowsLen) + col = column.NewColumnVarChar(field.Name, data) + case entity.FieldTypeJSON: + data := make([][]byte, 0, rowsLen) + col = column.NewColumnJSONBytes(field.Name, data) + case entity.FieldTypeArray: + col = NewArrayColumn(field) + if col == nil { + return nil, errors.Newf("unsupported element type %s for Array", field.ElementType.String()) + } + case entity.FieldTypeFloatVector: + data := make([][]float32, 0, rowsLen) + dimStr, has := field.TypeParams[entity.TypeParamDim] + if !has { + return nil, errors.New("vector field with no dim") + } + dim, err := strconv.ParseInt(dimStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("vector field with bad format dim: %s", err.Error()) + } + col = column.NewColumnFloatVector(field.Name, int(dim), data) + case entity.FieldTypeBinaryVector: + data := make([][]byte, 0, rowsLen) + dim, err := field.GetDim() + if err != nil { + return nil, err + } + col = column.NewColumnBinaryVector(field.Name, int(dim), data) + case entity.FieldTypeFloat16Vector: + data := make([][]byte, 0, rowsLen) + dim, err := field.GetDim() + if err != nil { + return nil, err + } + col = column.NewColumnFloat16Vector(field.Name, int(dim), data) + case entity.FieldTypeBFloat16Vector: + data := make([][]byte, 0, rowsLen) + dim, err := field.GetDim() + if err != nil { + return nil, err + } + col = column.NewColumnBFloat16Vector(field.Name, int(dim), data) + case entity.FieldTypeSparseVector: + data := make([]entity.SparseEmbedding, 0, rowsLen) + col = column.NewColumnSparseVectors(field.Name, data) + case entity.FieldTypeInt8Vector: + data := make([][]int8, 0, rowsLen) + dim, err := field.GetDim() + if err != nil { + return nil, err + } + col = column.NewColumnInt8Vector(field.Name, int(dim), data) + } + + if 
field.Nullable { + col.SetNullable(true) + } + return col, nil + } + } + return result +} + func NewArrayColumn(f *entity.Field) column.Column { switch f.ElementType { case entity.FieldTypeBool: diff --git a/client/row/data_test.go b/client/row/data_test.go index 064475dcbe..1f1c8334c1 100644 --- a/client/row/data_test.go +++ b/client/row/data_test.go @@ -26,7 +26,7 @@ type ValidStruct struct { type ValidStruct2 struct { ID int64 `milvus:"primary_key"` Vector [16]float32 - Vector2 [4]byte + Attr1 float64 Ignored bool `milvus:"-"` } @@ -110,7 +110,7 @@ func (s *RowsSuite) TestRowsToColumns() { _, err = AnyToColumns([]any{&ValidStruct{}}, &entity.Schema{ Fields: []*entity.Field{ { - Name: "int64", + Name: "Attr1", DataType: entity.FieldTypeInt64, }, }, diff --git a/tests/go_client/testcases/insert_test.go b/tests/go_client/testcases/insert_test.go index 1acb3feb77..f446ca16ba 100644 --- a/tests/go_client/testcases/insert_test.go +++ b/tests/go_client/testcases/insert_test.go @@ -721,7 +721,7 @@ func TestInsertRowFieldNameNotMatch(t *testing.T) { // insert rows, with json key name: int64 rows := hp.GenInt64VecRows(10, false, false, *hp.TNewDataOption()) _, errInsert := mc.Insert(ctx, client.NewRowBasedInsertOption(schema.CollectionName, rows...)) - common.CheckErr(t, errInsert, false, "row 0 does not has field pk") + common.CheckErr(t, errInsert, false, "fieldSchema(pk) has no corresponding fieldData pass in") } // test field name: pk, row json name: int64 @@ -806,7 +806,7 @@ func TestInsertDisableAutoIDRow(t *testing.T) { rowsWithoutPk = append(rowsWithoutPk, &baseRow) } _, err1 := mc.Insert(ctx, client.NewRowBasedInsertOption(schema.CollectionName, rowsWithoutPk...)) - common.CheckErr(t, err1, false, "row 0 does not has field int64") + common.CheckErr(t, err1, false, "fieldSchema(int64) has no corresponding fieldData pass in") } func TestInsertEnableAutoIDRow(t *testing.T) {