diff --git a/client/milvusclient/write_options.go b/client/milvusclient/write_options.go index ac8ef07a88..aa3534ab64 100644 --- a/client/milvusclient/write_options.go +++ b/client/milvusclient/write_options.go @@ -302,7 +302,8 @@ func NewColumnBasedInsertOption(collName string, columns ...column.Column) *colu type rowBasedDataOption struct { *columnBasedDataOption - rows []any + rows []any + keepAutoIDPk bool // keep user passed auto id pk field } func NewRowBasedInsertOption(collName string, rows ...any) *rowBasedDataOption { @@ -310,12 +311,13 @@ func NewRowBasedInsertOption(collName string, rows ...any) *rowBasedDataOption { columnBasedDataOption: &columnBasedDataOption{ collName: collName, }, - rows: rows, + rows: rows, + keepAutoIDPk: false, } } func (opt *rowBasedDataOption) InsertRequest(coll *entity.Collection) (*milvuspb.InsertRequest, error) { - columns, err := row.AnyToColumns(opt.rows, coll.Schema) + columns, err := row.AnyToColumns(opt.rows, opt.keepAutoIDPk, coll.Schema) if err != nil { return nil, err } @@ -333,7 +335,7 @@ func (opt *rowBasedDataOption) InsertRequest(coll *entity.Collection) (*milvuspb } func (opt *rowBasedDataOption) UpsertRequest(coll *entity.Collection) (*milvuspb.UpsertRequest, error) { - columns, err := row.AnyToColumns(opt.rows, coll.Schema) + columns, err := row.AnyToColumns(opt.rows, opt.keepAutoIDPk, coll.Schema) if err != nil { return nil, err } @@ -373,6 +375,11 @@ func (opt *rowBasedDataOption) WriteBackPKs(sch *entity.Schema, pks column.Colum return nil } +func (opt *rowBasedDataOption) WithKeepAutoIDPk(keepPk bool) *rowBasedDataOption { + opt.keepAutoIDPk = keepPk + return opt +} + type DeleteOption interface { Request() *milvuspb.DeleteRequest } diff --git a/client/row/data.go b/client/row/data.go index 27acdebecd..afdddbdace 100644 --- a/client/row/data.go +++ b/client/row/data.go @@ -64,7 +64,7 @@ const ( // AnyToColumns converts input rows into column-based data. // when schemas are provided, this method will use 0-th element // otherwise, it shall try to parse schema from row[0] -func AnyToColumns(rows []interface{}, schemas ...*entity.Schema) ([]column.Column, error) { +func AnyToColumns(rows []interface{}, keepPkField bool, schemas ...*entity.Schema) ([]column.Column, error) { rowsLen := len(rows) if rowsLen == 0 { return []column.Column{}, errors.New("0 length column") @@ -123,7 +123,7 @@ func AnyToColumns(rows []interface{}, schemas ...*entity.Schema) ([]column.Colum for fieldName, candi := range set { fieldSch, ok := nameSchemas[fieldName] - if ok && fieldSch.PrimaryKey && fieldSch.AutoID { + if ok && fieldSch.PrimaryKey && fieldSch.AutoID && !keepPkField { // remove pk field from candidates set, avoid adding it into dynamic column delete(set, fieldName) continue diff --git a/client/row/data_test.go b/client/row/data_test.go index 1f1c8334c1..dc85bcf43c 100644 --- a/client/row/data_test.go +++ b/client/row/data_test.go @@ -41,11 +41,11 @@ type RowsSuite struct { func (s *RowsSuite) TestRowsToColumns() { s.Run("valid_cases", func() { - columns, err := AnyToColumns([]any{&ValidStruct{}}) + columns, err := AnyToColumns([]any{&ValidStruct{}}, false) s.Nil(err) s.Equal(10, len(columns)) - columns, err = AnyToColumns([]any{&ValidStruct2{}}) + columns, err = AnyToColumns([]any{&ValidStruct2{}}, false) s.Nil(err) s.Equal(3, len(columns)) }) @@ -55,7 +55,7 @@ func (s *RowsSuite) TestRowsToColumns() { ID int64 `milvus:"primary_key;auto_id"` Vector []float32 `milvus:"dim:32"` } - columns, err := AnyToColumns([]any{&AutoPK{}}) + columns, err := AnyToColumns([]any{&AutoPK{}}, false) s.Nil(err) s.Require().Equal(1, len(columns)) s.Equal("Vector", columns[0].Name()) @@ -66,7 +66,7 @@ func (s *RowsSuite) TestRowsToColumns() { ID int64 `milvus:"primary_key;auto_id"` Vector []byte `milvus:"dim:16;vector_type:bf16"` } - columns, err := AnyToColumns([]any{&BF16Struct{}}) + columns, err := AnyToColumns([]any{&BF16Struct{}}, false) s.Nil(err) s.Require().Equal(1, len(columns)) s.Equal("Vector", columns[0].Name()) @@ -78,7 +78,7 @@ func (s *RowsSuite) TestRowsToColumns() { ID int64 `milvus:"primary_key;auto_id"` Vector []byte `milvus:"dim:16;vector_type:fp16"` } - columns, err := AnyToColumns([]any{&FP16Struct{}}) + columns, err := AnyToColumns([]any{&FP16Struct{}}, false) s.Nil(err) s.Require().Equal(1, len(columns)) s.Equal("Vector", columns[0].Name()) @@ -90,7 +90,7 @@ func (s *RowsSuite) TestRowsToColumns() { ID int64 `milvus:"primary_key;auto_id"` Vector []int8 `milvus:"dim:16;vector_type:int8"` } - columns, err := AnyToColumns([]any{&Int8Struct{}}) + columns, err := AnyToColumns([]any{&Int8Struct{}}, false) s.Nil(err) s.Require().Equal(1, len(columns)) s.Equal("Vector", columns[0].Name()) @@ -99,15 +99,15 @@ func (s *RowsSuite) TestRowsToColumns() { s.Run("invalid_cases", func() { // empty input - _, err := AnyToColumns([]any{}) + _, err := AnyToColumns([]any{}, false) s.NotNil(err) // incompatible rows - _, err = AnyToColumns([]any{&ValidStruct{}, &ValidStruct2{}}) + _, err = AnyToColumns([]any{&ValidStruct{}, &ValidStruct2{}}, false) s.NotNil(err) // schema & row not compatible - _, err = AnyToColumns([]any{&ValidStruct{}}, &entity.Schema{ + _, err = AnyToColumns([]any{&ValidStruct{}}, false, &entity.Schema{ Fields: []*entity.Field{ { Name: "Attr1", @@ -121,7 +121,7 @@ func (s *RowsSuite) TestRowsToColumns() { func (s *RowsSuite) TestDynamicSchema() { s.Run("all_fallback_dynamic", func() { - columns, err := AnyToColumns([]any{&ValidStruct{}}, + columns, err := AnyToColumns([]any{&ValidStruct{}}, false, entity.NewSchema().WithDynamicFieldEnabled(true), ) s.NoError(err) @@ -129,7 +129,7 @@ func (s *RowsSuite) TestDynamicSchema() { }) s.Run("dynamic_not_found", func() { - _, err := AnyToColumns([]any{&ValidStruct{}}, + _, err := AnyToColumns([]any{&ValidStruct{}}, false, entity.NewSchema().WithField( entity.NewField().WithName("ID").WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true), ).WithDynamicFieldEnabled(true), diff --git a/tests/go_client/testcases/insert_test.go b/tests/go_client/testcases/insert_test.go index f446ca16ba..e1dfdfe2f0 100644 --- a/tests/go_client/testcases/insert_test.go +++ b/tests/go_client/testcases/insert_test.go @@ -601,6 +601,41 @@ func TestInsertDefaultRows(t *testing.T) { } } +func TestInsertDefaultRowsWithKeepAutoIDPk(t *testing.T) { + t.Parallel() + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + cp := hp.NewCreateCollectionParams(hp.Int64Vec) + _, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, cp, hp.TNewFieldsOption().TWithAutoID(true), hp.TNewSchemaOption()) + log.Info("fields", zap.Any("FieldNames", schema.Fields)) + err := mc.AlterCollectionProperties(ctx, client.NewAlterCollectionPropertiesOption(schema.CollectionName).WithProperty("allow_insert_auto_id", true)) + common.CheckErr(t, err, true) + + // insert rows + rows := hp.GenInt64VecRows(common.DefaultNb, false, false, *hp.TNewDataOption()) + log.Info("rows data", zap.Any("rows[8]", rows[8])) + ids, err := mc.Insert(ctx, client.NewRowBasedInsertOption(schema.CollectionName, rows...).WithKeepAutoIDPk(true)) + common.CheckErr(t, err, true) + int64Values := make([]int64, 0, common.DefaultNb) + for i := 0; i < common.DefaultNb; i++ { + int64Values = append(int64Values, int64(i+1)) + } + common.CheckInsertResult(t, column.NewColumnInt64(common.DefaultInt64FieldName, int64Values), ids) + require.Equal(t, ids.InsertCount, int64(common.DefaultNb)) + + // flush and check row count + flushTask, errFlush := mc.Flush(ctx, client.NewFlushOption(schema.CollectionName)) + common.CheckErr(t, errFlush, true) + errFlush = flushTask.Await(ctx) + common.CheckErr(t, errFlush, true) + + // check collection stats + stats, err := mc.GetCollectionStats(ctx, client.NewGetCollectionStatsOption(schema.CollectionName)) + common.CheckErr(t, err, true) + require.Equal(t, map[string]string{common.RowCount: strconv.Itoa(common.DefaultNb)}, stats) +} + // test insert rows enable or disable dynamic field func TestInsertAllFieldsRows(t *testing.T) { t.Skip("https://github.com/milvus-io/milvus/issues/33459")