From 14563ad2b3ed7969ad9a7c9f684a13610770c64c Mon Sep 17 00:00:00 2001
From: groot
Date: Wed, 28 May 2025 18:02:28 +0800
Subject: [PATCH] enhance: bulkinsert handles nullable/default (#42127)

issue: https://github.com/milvus-io/milvus/issues/42096, https://github.com/milvus-io/milvus/issues/42130

Signed-off-by: yhmo
---
 internal/datanode/importv2/hash.go         |  35 ++-
 internal/datanode/importv2/task_import.go  |   6 +-
 internal/datanode/importv2/util.go         | 185 ++++++++++++-
 internal/datanode/importv2/util_test.go    | 286 ++++++++++++++++++++-
 internal/util/importutilv2/parquet/util.go |  14 +-
 5 files changed, 502 insertions(+), 24 deletions(-)

diff --git a/internal/datanode/importv2/hash.go b/internal/datanode/importv2/hash.go
index b3eba4d03b..d05c9d254f 100644
--- a/internal/datanode/importv2/hash.go
+++ b/internal/datanode/importv2/hash.go
@@ -104,6 +104,30 @@ func HashDeleteData(task Task, delData *storage.DeleteData) ([]*storage.DeleteDa
 	return res, nil
 }
 
+// this method is only for GetRowsStats() to get a row from storage.InsertData
+// GetRowsStats() is called by PreImportTask; some nullable/default_value fields in the storage.InsertData may have zero rows
+func getRowFromInsertData(rows *storage.InsertData, i int) map[int64]interface{} {
+	res := make(map[int64]interface{})
+	for field, data := range rows.Data {
+		if data.RowNum() > i {
+			res[field] = data.GetRow(i)
+		}
+	}
+	return res
+}
+
+// this method is only for GetRowsStats() to get a row size from storage.InsertData
+// GetRowsStats() is called by PreImportTask; some nullable/default_value fields in the storage.InsertData may have zero rows
+func getRowSizeFromInsertData(rows *storage.InsertData, i int) int {
+	size := 0
+	for _, data := range rows.Data {
+		if data.RowNum() > i {
+			size += data.GetRowSize(i)
+		}
+	}
+	return size
+}
+
 func GetRowsStats(task Task, rows *storage.InsertData) (map[string]*datapb.PartitionImportStats, error) {
 	var (
 		schema  = task.GetSchema()
@@ -127,7 +151,7 @@ func GetRowsStats(task Task, rows *storage.InsertData) (map[string]*datapb.Parti
 		hashDataSize[i] = make([]int, partitionNum)
 	}
 
-	rowNum := GetInsertDataRowCount(rows, schema)
+	rowNum, _ := GetInsertDataRowCount(rows, schema)
 	if pkField.GetAutoID() {
 		fn := hashByPartition(int64(partitionNum), partKeyField)
 		rows.Data = lo.PickBy(rows.Data, func(fieldID int64, _ storage.FieldData) bool {
@@ -136,9 +160,10 @@ func GetRowsStats(task Task, rows *storage.InsertData) (map[string]*datapb.Parti
 		hashByPartRowsCount := make([]int, partitionNum)
 		hashByPartDataSize := make([]int, partitionNum)
 		for i := 0; i < rowNum; i++ {
-			p := fn(rows.GetRow(i)[id2])
+			row := getRowFromInsertData(rows, i)
+			p := fn(row[id2])
 			hashByPartRowsCount[p]++
-			hashByPartDataSize[p] += rows.GetRowSize(i)
+			hashByPartDataSize[p] += getRowSizeFromInsertData(rows, i)
 		}
 		// When autoID is enabled, the generated IDs will be evenly hashed across all channels.
 		// Therefore, here we just assign an average number of rows to each channel.
@@ -152,10 +177,10 @@ func GetRowsStats(task Task, rows *storage.InsertData) (map[string]*datapb.Parti
 		f1 := hashByVChannel(int64(channelNum), pkField)
 		f2 := hashByPartition(int64(partitionNum), partKeyField)
 		for i := 0; i < rowNum; i++ {
-			row := rows.GetRow(i)
+			row := getRowFromInsertData(rows, i)
 			p1, p2 := f1(row[id1]), f2(row[id2])
 			hashRowsCount[p1][p2]++
-			hashDataSize[p1][p2] += rows.GetRowSize(i)
+			hashDataSize[p1][p2] += getRowSizeFromInsertData(rows, i)
 		}
 	}
 
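The two helpers above replace storage.InsertData.GetRow/GetRowSize in the hashing loops because, after this change, a nullable or default_value column that is absent from the data file legitimately has zero rows at pre-import time, and indexing into it would panic. A minimal, self-contained sketch of that guard (the fieldData/insertData types below are simplified stand-ins, not the real storage package):

    package main

    import "fmt"

    // Simplified stand-ins for storage.FieldData / storage.InsertData,
    // only to illustrate the zero-row guard used by getRowFromInsertData.
    type fieldData struct{ rows []interface{} }

    func (f *fieldData) RowNum() int              { return len(f.rows) }
    func (f *fieldData) GetRow(i int) interface{} { return f.rows[i] }

    type insertData struct{ Data map[int64]*fieldData }

    // getRow mirrors the patch's getRowFromInsertData: columns with fewer
    // rows than i (e.g. an absent nullable column with zero rows) are
    // skipped instead of causing an out-of-range panic.
    func getRow(d *insertData, i int) map[int64]interface{} {
        res := make(map[int64]interface{})
        for field, data := range d.Data {
            if data.RowNum() > i {
                res[field] = data.GetRow(i)
            }
        }
        return res
    }

    func main() {
        d := &insertData{Data: map[int64]*fieldData{
            100: {rows: []interface{}{int64(1), int64(2)}}, // pk column, 2 rows
            102: {rows: nil},                               // nullable column, zero rows
        }}
        fmt.Println(getRow(d, 1)) // map[100:2] — field 102 is simply absent
    }
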
diff --git a/internal/datanode/importv2/task_import.go b/internal/datanode/importv2/task_import.go
index 65861e0ff9..0eea6cd274 100644
--- a/internal/datanode/importv2/task_import.go
+++ b/internal/datanode/importv2/task_import.go
@@ -191,7 +191,7 @@ func (t *ImportTask) importFile(reader importutilv2.Reader) error {
 			}
 			return err
 		}
-		rowNum := GetInsertDataRowCount(data, t.GetSchema())
+		rowNum, _ := GetInsertDataRowCount(data, t.GetSchema())
 		if rowNum == 0 {
 			log.Info("0 row was imported, the data may have been deleted", WrapLogFields(t)...)
 			continue
@@ -200,6 +200,10 @@ func (t *ImportTask) importFile(reader importutilv2.Reader) error {
 		if err != nil {
 			return err
 		}
+		err = AppendNullableDefaultFieldsData(t.GetSchema(), data, rowNum)
+		if err != nil {
+			return err
+		}
 		if !importutilv2.IsBackup(t.req.GetOptions()) {
 			err = RunEmbeddingFunction(t, data)
 			if err != nil {
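The ordering in importFile matters: GetInsertDataRowCount must derive the row count from a column that was actually read from the file (it now skips zero-row fillable columns), and only then can AppendNullableDefaultFieldsData pad the absent columns to that length. A toy, runnable illustration of that fill-to-length step (plain maps and slices with hypothetical values, not the Milvus API):

    package main

    import "fmt"

    // A toy version of the importFile ordering: first derive rowNum from a
    // column that has data, then pad the missing fillable column to match.
    func main() {
        data := map[string][]int64{
            "pk":    {1, 2, 3}, // read from the file
            "dummy": {},        // nullable/default column absent from the file
        }

        // step 1: the row count comes from a non-empty column
        rowNum := len(data["pk"])

        // step 2: pad the fillable column with a default value (here 99)
        if len(data["dummy"]) == 0 {
            filled := make([]int64, rowNum)
            for i := range filled {
                filled[i] = 99
            }
            data["dummy"] = filled
        }
        fmt.Println(data) // map[dummy:[99 99 99] pk:[1 2 3]]
    }
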
diff --git a/internal/datanode/importv2/util.go b/internal/datanode/importv2/util.go
index 61691689f6..cd7b2c7def 100644
--- a/internal/datanode/importv2/util.go
+++ b/internal/datanode/importv2/util.go
@@ -141,17 +141,9 @@ func CheckRowsEqual(schema *schemapb.CollectionSchema, data *storage.InsertData)
 		return field.GetFieldID()
 	})
 
-	var field int64
-	var rows int
+	rows, field := GetInsertDataRowCount(data, schema)
 	for fieldID, d := range data.Data {
-		if idToField[fieldID].GetIsPrimaryKey() && idToField[fieldID].GetAutoID() {
-			continue
-		}
-		field, rows = fieldID, d.RowNum()
-		break
-	}
-	for fieldID, d := range data.Data {
-		if idToField[fieldID].GetIsPrimaryKey() && idToField[fieldID].GetAutoID() {
+		if d.RowNum() == 0 && (CanBeZeroRowField(idToField[fieldID])) {
 			continue
 		}
 		if d.RowNum() != rows {
@@ -201,6 +193,156 @@ func AppendSystemFieldsData(task *ImportTask, data *storage.InsertData, rowNum i
 	return nil
 }
 
+type nullDefaultAppender[T any] struct {
+}
+
+func (h *nullDefaultAppender[T]) AppendDefault(fieldData storage.FieldData, defaultVal T, rowNum int) error {
+	values := make([]T, rowNum)
+	if fieldData.GetNullable() {
+		validData := make([]bool, rowNum)
+		for i := 0; i < rowNum; i++ {
+			validData[i] = true // all true
+			values[i] = defaultVal // fill with default value
+		}
+		return fieldData.AppendRows(values, validData)
+	} else {
+		for i := 0; i < rowNum; i++ {
+			values[i] = defaultVal // fill with default value
+		}
+		return fieldData.AppendDataRows(values)
+	}
+	return nil
+}
+
+func (h *nullDefaultAppender[T]) AppendNull(fieldData storage.FieldData, rowNum int) error {
+	if fieldData.GetNullable() {
+		values := make([]T, rowNum)
+		validData := make([]bool, rowNum)
+		for i := 0; i < rowNum; i++ {
+			validData[i] = false
+		}
+		return fieldData.AppendRows(values, validData)
+	}
+	return nil
+}
+
+func IsFillableField(field *schemapb.FieldSchema) bool {
+	nullable := field.GetNullable()
+	defaultVal := field.GetDefaultValue()
+	return nullable || defaultVal != nil
+}
+
+func AppendNullableDefaultFieldsData(schema *schemapb.CollectionSchema, data *storage.InsertData, rowNum int) error {
+	for _, field := range schema.GetFields() {
+		if !IsFillableField(field) {
+			continue
+		}
+		if tempData, ok := data.Data[field.GetFieldID()]; ok {
+			if tempData.RowNum() > 0 {
+				continue // values have been read from data file
+			}
+		}
+
+		// add a new column and fill with null or default
+		dataType := field.GetDataType()
+		fieldData, err := storage.NewFieldData(dataType, field, rowNum)
+		if err != nil {
+			return err
+		}
+		data.Data[field.GetFieldID()] = fieldData
+
+		nullable := field.GetNullable()
+		defaultVal := field.GetDefaultValue()
+
+		// bool/int8/int16/int32/int64/float/double/varchar/json/array can be null
+		// bool/int8/int16/int32/int64/float/double/varchar can have a default value
+		switch dataType {
+		case schemapb.DataType_Bool:
+			appender := &nullDefaultAppender[bool]{}
+			if defaultVal != nil {
+				v := defaultVal.GetBoolData()
+				err = appender.AppendDefault(fieldData, v, rowNum)
+			} else if nullable {
+				err = appender.AppendNull(fieldData, rowNum)
+			}
+		case schemapb.DataType_Int8:
+			appender := &nullDefaultAppender[int8]{}
+			if defaultVal != nil {
+				v := defaultVal.GetIntData()
+				err = appender.AppendDefault(fieldData, int8(v), rowNum)
+			} else if nullable {
+				err = appender.AppendNull(fieldData, rowNum)
+			}
+		case schemapb.DataType_Int16:
+			appender := &nullDefaultAppender[int16]{}
+			if defaultVal != nil {
+				v := defaultVal.GetIntData()
+				err = appender.AppendDefault(fieldData, int16(v), rowNum)
+			} else if nullable {
+				err = appender.AppendNull(fieldData, rowNum)
+			}
+		case schemapb.DataType_Int32:
+			appender := &nullDefaultAppender[int32]{}
+			if defaultVal != nil {
+				v := defaultVal.GetIntData()
+				err = appender.AppendDefault(fieldData, int32(v), rowNum)
+			} else if nullable {
+				err = appender.AppendNull(fieldData, rowNum)
+			}
+		case schemapb.DataType_Int64:
+			appender := &nullDefaultAppender[int64]{}
+			if defaultVal != nil {
+				v := defaultVal.GetLongData()
+				err = appender.AppendDefault(fieldData, v, rowNum)
+			} else if nullable {
+				err = appender.AppendNull(fieldData, rowNum)
+			}
+		case schemapb.DataType_Float:
+			appender := &nullDefaultAppender[float32]{}
+			if defaultVal != nil {
+				v := defaultVal.GetFloatData()
+				err = appender.AppendDefault(fieldData, v, rowNum)
+			} else if nullable {
+				err = appender.AppendNull(fieldData, rowNum)
+			}
+		case schemapb.DataType_Double:
+			appender := &nullDefaultAppender[float64]{}
+			if defaultVal != nil {
+				v := defaultVal.GetDoubleData()
+				err = appender.AppendDefault(fieldData, v, rowNum)
+			} else if nullable {
+				err = appender.AppendNull(fieldData, rowNum)
+			}
+		case schemapb.DataType_VarChar:
+			appender := &nullDefaultAppender[string]{}
+			if defaultVal != nil {
+				v := defaultVal.GetStringData()
+				err = appender.AppendDefault(fieldData, v, rowNum)
+			} else if nullable {
+				err = appender.AppendNull(fieldData, rowNum)
+			}
+		case schemapb.DataType_JSON:
+			if nullable {
+				appender := &nullDefaultAppender[[]byte]{}
+				err = appender.AppendNull(fieldData, rowNum)
+			}
+		case schemapb.DataType_Array:
+			if nullable {
+				appender := &nullDefaultAppender[*schemapb.ScalarField]{}
+				err = appender.AppendNull(fieldData, rowNum)
+			}
+		default:
+			return fmt.Errorf("unexpected data type: %d, cannot be filled with default value", dataType)
+		}
+
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
 func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error {
 	if err := RunBm25Function(task, data); err != nil {
 		return err
@@ -275,19 +417,34 @@ func RunBm25Function(task *ImportTask, data *storage.InsertData) error {
 	return nil
 }
 
-func GetInsertDataRowCount(data *storage.InsertData, schema *schemapb.CollectionSchema) int {
+func CanBeZeroRowField(field *schemapb.FieldSchema) bool {
+	if field.GetIsPrimaryKey() && field.GetAutoID() {
+		return true // auto-generated primary key, the row count must be 0
+	}
+	if field.GetIsDynamic() {
+		return true // dynamic field, row count could be 0
+	}
+	if IsFillableField(field) {
+		return true // nullable/default_value field can be automatically filled if the file doesn't contain this column
+	}
+	return false
+}
+
+func GetInsertDataRowCount(data *storage.InsertData, schema *schemapb.CollectionSchema) (int, int64) {
 	fields := lo.KeyBy(schema.GetFields(), func(field *schemapb.FieldSchema) int64 {
 		return field.GetFieldID()
 	})
 	for fieldID, fd := range data.Data {
-		if fields[fieldID].GetIsDynamic() {
+		if fd.RowNum() == 0 && CanBeZeroRowField(fields[fieldID]) {
 			continue
 		}
+
+		// each collection must contain at least one vector field, so at least one field must have a non-zero row count
 		if fd.RowNum() != 0 {
-			return fd.RowNum()
+			return fd.RowNum(), fieldID
 		}
 	}
-	return 0
+	return 0, 0
 }
 
 func LogStats(manager TaskManager) {
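nullDefaultAppender uses a single generic type so that the per-type switch in AppendNullableDefaultFieldsData only has to choose T and extract the default value; AppendRows carries a validity bitmap for nullable columns while AppendDataRows does not. A stand-alone sketch of the same pattern (the column type below is a simplified stand-in for storage.FieldData, not the real interface):

    package main

    import "fmt"

    // column is a stand-in for storage.FieldData: values plus a validity
    // bitmap where false marks a null row.
    type column[T any] struct {
        values    []T
        validData []bool
        nullable  bool
    }

    // appendDefault mirrors AppendDefault: every row gets the default
    // value, and nullable columns also record the row as valid.
    func appendDefault[T any](c *column[T], defaultVal T, rowNum int) {
        for i := 0; i < rowNum; i++ {
            c.values = append(c.values, defaultVal)
            if c.nullable {
                c.validData = append(c.validData, true)
            }
        }
    }

    // appendNull mirrors AppendNull: only nullable columns accept nulls;
    // the stored value is a zero-value placeholder.
    func appendNull[T any](c *column[T], rowNum int) {
        if !c.nullable {
            return
        }
        var zero T
        for i := 0; i < rowNum; i++ {
            c.values = append(c.values, zero)
            c.validData = append(c.validData, false)
        }
    }

    func main() {
        ints := &column[int32]{nullable: false}
        appendDefault(ints, 99, 3)
        fmt.Println(ints.values) // [99 99 99]

        strs := &column[string]{nullable: true}
        appendNull(strs, 2)
        fmt.Println(strs.validData) // [false false]
    }
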
diff --git a/internal/datanode/importv2/util_test.go b/internal/datanode/importv2/util_test.go
index 3dd4330dc4..f93b8092cc 100644
--- a/internal/datanode/importv2/util_test.go
+++ b/internal/datanode/importv2/util_test.go
@@ -17,6 +17,7 @@
 package importv2
 
 import (
+	"fmt"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -24,6 +25,7 @@ import (
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/internal/allocator"
+	"github.com/milvus-io/milvus/internal/storage"
 	"github.com/milvus-io/milvus/internal/util/testutil"
 	"github.com/milvus-io/milvus/pkg/v2/common"
 	"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
@@ -71,7 +73,7 @@ func Test_AppendSystemFieldsData(t *testing.T) {
 	assert.Equal(t, 0, insertData.Data[pkField.GetFieldID()].RowNum())
 	assert.Nil(t, insertData.Data[common.RowIDField])
 	assert.Nil(t, insertData.Data[common.TimeStampField])
-	rowNum := GetInsertDataRowCount(insertData, task.GetSchema())
+	rowNum, _ := GetInsertDataRowCount(insertData, task.GetSchema())
 	err = AppendSystemFieldsData(task, insertData, rowNum)
 	assert.NoError(t, err)
 	assert.Equal(t, count, insertData.Data[pkField.GetFieldID()].RowNum())
@@ -85,7 +87,7 @@ func Test_AppendSystemFieldsData(t *testing.T) {
 	assert.Equal(t, 0, insertData.Data[pkField.GetFieldID()].RowNum())
 	assert.Nil(t, insertData.Data[common.RowIDField])
 	assert.Nil(t, insertData.Data[common.TimeStampField])
-	rowNum = GetInsertDataRowCount(insertData, task.GetSchema())
+	rowNum, _ = GetInsertDataRowCount(insertData, task.GetSchema())
 	err = AppendSystemFieldsData(task, insertData, rowNum)
 	assert.NoError(t, err)
 	assert.Equal(t, count, insertData.Data[pkField.GetFieldID()].RowNum())
@@ -175,3 +177,283 @@ func Test_PickSegment(t *testing.T) {
 	_, err := PickSegment(task.req.GetRequestSegments(), "ch-2", 20)
 	assert.Error(t, err)
 }
+
+func Test_AppendNullableDefaultFieldsData(t *testing.T) {
+	buildSchemaFn := func() *schemapb.CollectionSchema {
+		fields := make([]*schemapb.FieldSchema, 0)
+		fields = append(fields, &schemapb.FieldSchema{
+			FieldID:      100,
+			Name:         "pk",
+			DataType:     schemapb.DataType_Int64,
+			IsPrimaryKey: true,
+			AutoID:       false,
+		})
+		fields = append(fields, &schemapb.FieldSchema{
+			FieldID:  101,
+			Name:     "vec",
+			DataType: schemapb.DataType_FloatVector,
+			TypeParams: []*commonpb.KeyValuePair{
+				{
+					Key:   common.DimKey,
+					Value: "4",
+				},
+			},
+		})
+		fields = append(fields, &schemapb.FieldSchema{
+			FieldID:  102,
+			Name:     "dummy",
+			DataType: schemapb.DataType_Int32,
+			Nullable: true,
+		})
+
+		return &schemapb.CollectionSchema{
+			Fields: fields,
+		}
+	}
+
+	const count = 10
+	tests := []struct {
+		name       string
+		fieldID    int64
+		dataType   schemapb.DataType
+		nullable   bool
+		defaultVal *schemapb.ValueField
+	}{
+		// nullable tests
+		{
+			name:     "bool is nullable",
+			fieldID:  200,
+			dataType: schemapb.DataType_Bool,
+			nullable: true,
+		},
+		{
+			name:     "int8 is nullable",
+			fieldID:  200,
+			dataType: schemapb.DataType_Int8,
+			nullable: true,
+		},
+		{
+			name:     "int16 is nullable",
+			fieldID:  200,
+			dataType: schemapb.DataType_Int16,
+			nullable: true,
+		},
+		{
+			name:     "int32 is nullable",
+			fieldID:  200,
+			dataType: schemapb.DataType_Int32,
+			nullable: true,
+		},
+		{
+			name:       "int64 is nullable",
+			fieldID:    200,
+			dataType:   schemapb.DataType_Int64,
+			nullable:   true,
+			defaultVal: nil,
+		},
+		{
+			name:     "float is nullable",
+			fieldID:  200,
+			dataType: schemapb.DataType_Float,
+			nullable: true,
+		},
+		{
+			name:     "double is nullable",
+			fieldID:  200,
+			dataType: schemapb.DataType_Double,
+			nullable: true,
+		},
+		{
+			name:     "varchar is nullable",
+			fieldID:  200,
+			dataType: schemapb.DataType_VarChar,
+			nullable: true,
+		},
+		{
+			name:     "json is nullable",
+			fieldID:  200,
+			dataType: schemapb.DataType_JSON,
+			nullable: true,
+		},
+		{
+			name:     "array is nullable",
+			fieldID:  200,
+			dataType: schemapb.DataType_Array,
+			nullable: true,
+		},
+
+		// default value tests
+		{
+			name:     "bool is default",
+			fieldID:  200,
+			dataType: schemapb.DataType_Bool,
+			defaultVal: &schemapb.ValueField{
+				Data: &schemapb.ValueField_BoolData{
+					BoolData: true,
+				},
+			},
+		},
+		{
+			name:     "int8 is default",
+			fieldID:  200,
+			dataType: schemapb.DataType_Int8,
+			defaultVal: &schemapb.ValueField{
+				Data: &schemapb.ValueField_IntData{
+					IntData: 99,
+				},
+			},
+		},
+		{
+			name:     "int16 is default",
+			fieldID:  200,
+			dataType: schemapb.DataType_Int16,
+			defaultVal: &schemapb.ValueField{
+				Data: &schemapb.ValueField_IntData{
+					IntData: 99,
+				},
+			},
+		},
+		{
+			name:     "int32 is default",
+			fieldID:  200,
+			dataType: schemapb.DataType_Int32,
+			defaultVal: &schemapb.ValueField{
+				Data: &schemapb.ValueField_IntData{
+					IntData: 99,
+				},
+			},
+		},
+		{
+			name:     "int64 is default",
+			fieldID:  200,
+			dataType: schemapb.DataType_Int64,
+			nullable: true,
+			defaultVal: &schemapb.ValueField{
+				Data: &schemapb.ValueField_LongData{
+					LongData: 99,
+				},
+			},
+		},
+		{
+			name:     "float is default",
+			fieldID:  200,
+			dataType: schemapb.DataType_Float,
+			defaultVal: &schemapb.ValueField{
+				Data: &schemapb.ValueField_FloatData{
+					FloatData: 99.99,
+				},
+			},
+		},
+		{
+			name:     "double is default",
+			fieldID:  200,
+			dataType: schemapb.DataType_Double,
+			defaultVal: &schemapb.ValueField{
+				Data: &schemapb.ValueField_DoubleData{
+					DoubleData: 99.99,
+				},
+			},
+		},
+		{
+			name:     "varchar is default",
+			fieldID:  200,
+			dataType: schemapb.DataType_VarChar,
+			defaultVal: &schemapb.ValueField{
+				Data: &schemapb.ValueField_StringData{
+					StringData: "hello world",
+				},
+			},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			schema := buildSchemaFn()
+			fieldSchema := &schemapb.FieldSchema{
+				FieldID:      tt.fieldID,
+				Name:         fmt.Sprintf("field_%d", tt.fieldID),
+				DataType:     tt.dataType,
+				Nullable:     tt.nullable,
+				DefaultValue: tt.defaultVal,
+			}
+			if tt.dataType == schemapb.DataType_Array {
+				fieldSchema.ElementType = schemapb.DataType_Int64
+				fieldSchema.TypeParams = append(fieldSchema.TypeParams, &commonpb.KeyValuePair{Key: common.MaxCapacityKey, Value: "100"})
+			} else if tt.dataType == schemapb.DataType_VarChar {
+				fieldSchema.TypeParams = append(fieldSchema.TypeParams, &commonpb.KeyValuePair{Key: common.MaxLengthKey, Value: "100"})
+			}
+
+			insertData, err := testutil.CreateInsertData(schema, count)
+			assert.NoError(t, err)
+
+			schema.Fields = append(schema.Fields, fieldSchema)
+
+			fieldData, err := storage.NewFieldData(fieldSchema.GetDataType(), fieldSchema, 0)
+			assert.NoError(t, err)
+			insertData.Data[fieldSchema.GetFieldID()] = fieldData
+
+			err = AppendNullableDefaultFieldsData(schema, insertData, count)
+			assert.NoError(t, err)
+
+			for fieldID, fieldData := range insertData.Data {
+				if fieldID < int64(200) {
+					continue
+				}
+				assert.Equal(t, count, fieldData.RowNum())
+
+				if tt.nullable {
+					assert.True(t, fieldData.GetNullable())
+				}
+
+				if tt.defaultVal != nil {
+					switch tt.dataType {
+					case schemapb.DataType_Bool:
+						tempFieldData := fieldData.(*storage.BoolFieldData)
+						for _, v := range tempFieldData.Data {
+							assert.True(t, v)
+						}
+					case schemapb.DataType_Int8:
+						tempFieldData := fieldData.(*storage.Int8FieldData)
+						for _, v := range tempFieldData.Data {
+							assert.Equal(t, int8(99), v)
+						}
+					case schemapb.DataType_Int16:
+						tempFieldData := fieldData.(*storage.Int16FieldData)
+						for _, v := range tempFieldData.Data {
+							assert.Equal(t, int16(99), v)
+						}
+					case schemapb.DataType_Int32:
+						tempFieldData := fieldData.(*storage.Int32FieldData)
+						for _, v := range tempFieldData.Data {
+							assert.Equal(t, int32(99), v)
+						}
+					case schemapb.DataType_Int64:
+						tempFieldData := fieldData.(*storage.Int64FieldData)
+						for _, v := range tempFieldData.Data {
+							assert.Equal(t, int64(99), v)
+						}
+					case schemapb.DataType_Float:
+						tempFieldData := fieldData.(*storage.FloatFieldData)
+						for _, v := range tempFieldData.Data {
+							assert.Equal(t, float32(99.99), v)
+						}
+					case schemapb.DataType_Double:
+						tempFieldData := fieldData.(*storage.DoubleFieldData)
+						for _, v := range tempFieldData.Data {
+							assert.Equal(t, float64(99.99), v)
+						}
+					case schemapb.DataType_VarChar:
+						tempFieldData := fieldData.(*storage.StringFieldData)
+						for _, v := range tempFieldData.Data {
+							assert.Equal(t, "hello world", v)
+						}
+					default:
+					}
+				} else if tt.nullable {
+					for i := 0; i < count; i++ {
+						assert.Nil(t, fieldData.GetRow(i))
+					}
+				}
+			}
+		})
+	}
+}
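The null-path assertions above depend on FieldData.GetRow returning nil for a row whose validity flag is false. A minimal stand-in showing that behavior under the same assumption (not the real storage implementation):

    package main

    import "fmt"

    // nullableColumn mimics the GetRow contract the test asserts: a row
    // whose valid flag is false reads back as nil.
    type nullableColumn struct {
        values    []int32
        validData []bool
    }

    func (c *nullableColumn) GetRow(i int) interface{} {
        if !c.validData[i] {
            return nil
        }
        return c.values[i]
    }

    func main() {
        c := &nullableColumn{values: []int32{0, 7}, validData: []bool{false, true}}
        fmt.Println(c.GetRow(0), c.GetRow(1)) // <nil> 7
    }
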
diff --git a/internal/util/importutilv2/parquet/util.go b/internal/util/importutilv2/parquet/util.go
index d046981e15..a60e4d9b08 100644
--- a/internal/util/importutilv2/parquet/util.go
+++ b/internal/util/importutilv2/parquet/util.go
@@ -67,6 +67,10 @@ func CreateFieldReaders(ctx context.Context, fileReader *pqarrow.FileReader, sch
 			return nil, merr.WrapErrImportFailed(
 				fmt.Sprintf("the primary key '%s' is auto-generated, no need to provide", field.GetName()))
 		}
+		if field.GetIsFunctionOutput() {
+			return nil, merr.WrapErrImportFailed(
+				fmt.Sprintf("the field '%s' is a function output, no need to provide", field.GetName()))
+		}
 
 		cr, err := NewFieldReader(ctx, fileReader, i, field)
 		if err != nil {
@@ -80,7 +84,8 @@ func CreateFieldReaders(ctx context.Context, fileReader *pqarrow.FileReader, sch
 	}
 
 	for _, field := range nameToField {
-		if typeutil.IsAutoPKField(field) || field.GetIsDynamic() || field.GetIsFunctionOutput() {
+		if typeutil.IsAutoPKField(field) || field.GetIsDynamic() || field.GetIsFunctionOutput() ||
+			field.GetNullable() || field.GetDefaultValue() != nil {
 			continue
 		}
 		if _, ok := crs[field.GetFieldID()]; !ok {
@@ -285,12 +290,17 @@ func isSchemaEqual(schema *schemapb.CollectionSchema, arrSchema *arrow.Schema) e
 		return field.Name
 	})
 	for _, field := range schema.GetFields() {
+		// ignore autoPKField and functionOutputField
 		if typeutil.IsAutoPKField(field) || field.GetIsFunctionOutput() {
 			continue
 		}
 		arrField, ok := arrNameToField[field.GetName()]
 		if !ok {
-			if field.GetIsDynamic() {
+			// Special fields need not be provided in data files; if the parquet file doesn't contain such a field, skip the comparison:
+			// 1. dynamic field (named "$meta"), ignore
+			// 2. nullable field, filled with null values
+			// 3. default-value field, filled with the default value
+			if field.GetIsDynamic() || field.GetNullable() || field.GetDefaultValue() != nil {
 				continue
 			}
 			return merr.WrapErrImportFailed(fmt.Sprintf("field '%s' not in arrow schema", field.GetName()))
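On the parquet side, both CreateFieldReaders and isSchemaEqual now tolerate a collection field that is missing from the file when it is dynamic, nullable, or carries a default value, since the datanode fills it during import. A simplified sketch of that relaxed check (stand-in types, not the Milvus protos or arrow schema):

    package main

    import "fmt"

    // field is a simplified stand-in for *schemapb.FieldSchema.
    type field struct {
        name       string
        isDynamic  bool
        nullable   bool
        hasDefault bool
    }

    // checkSchema mirrors the relaxed comparison: a collection field
    // absent from the parquet columns is acceptable when the import
    // path can fill it in automatically.
    func checkSchema(collection []field, parquetCols map[string]bool) error {
        for _, f := range collection {
            if parquetCols[f.name] {
                continue // present in the file
            }
            if f.isDynamic || f.nullable || f.hasDefault {
                continue // can be filled automatically
            }
            return fmt.Errorf("field '%s' not in arrow schema", f.name)
        }
        return nil
    }

    func main() {
        collection := []field{
            {name: "pk"},
            {name: "dummy", nullable: true},
        }
        parquetCols := map[string]bool{"pk": true} // "dummy" column missing
        fmt.Println(checkSchema(collection, parquetCols)) // <nil>
    }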