enhance: bulkinsert handles nullable/default (#42127)

issue: https://github.com/milvus-io/milvus/issues/42096,
https://github.com/milvus-io/milvus/issues/42130

Signed-off-by: yhmo <yihua.mo@zilliz.com>
groot 2025-05-28 18:02:28 +08:00 committed by GitHub
parent 79b51cbb73
commit 14563ad2b3
5 changed files with 502 additions and 24 deletions


@@ -104,6 +104,30 @@ func HashDeleteData(task Task, delData *storage.DeleteData) ([]*storage.DeleteDa
return res, nil
}
// this method is only for GetRowsStats() to get a row from storage.InsertData
// GetRowsStats() is called by PreImportTask; some nullable/default_value fields in the storage.InsertData could have zero rows
func getRowFromInsertData(rows *storage.InsertData, i int) map[int64]interface{} {
res := make(map[int64]interface{})
for field, data := range rows.Data {
if data.RowNum() > i {
res[field] = data.GetRow(i)
}
}
return res
}
// this method is only for GetRowsStats() to get a row from storage.InsertData
// GetRowsStats() is called by PreImportTask; some nullable/default_value fields in the storage.InsertData could have zero rows
func getRowSizeFromInsertData(rows *storage.InsertData, i int) int {
size := 0
for _, data := range rows.Data {
if data.RowNum() > i {
size += data.GetRowSize(i)
}
}
return size
}
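// Hypothetical example: if a nullable "age" column was absent from the source
// file, its FieldData may hold 0 rows while the pk column holds N rows; both
// helpers above skip such zero-row columns instead of indexing past their end
// with GetRow(i)/GetRowSize(i).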
func GetRowsStats(task Task, rows *storage.InsertData) (map[string]*datapb.PartitionImportStats, error) {
var (
schema = task.GetSchema()
@@ -127,7 +151,7 @@ func GetRowsStats(task Task, rows *storage.InsertData) (map[string]*datapb.Parti
hashDataSize[i] = make([]int, partitionNum)
}
rowNum := GetInsertDataRowCount(rows, schema)
rowNum, _ := GetInsertDataRowCount(rows, schema)
if pkField.GetAutoID() {
fn := hashByPartition(int64(partitionNum), partKeyField)
rows.Data = lo.PickBy(rows.Data, func(fieldID int64, _ storage.FieldData) bool {
@@ -136,9 +160,10 @@ func GetRowsStats(task Task, rows *storage.InsertData) (map[string]*datapb.Parti
hashByPartRowsCount := make([]int, partitionNum)
hashByPartDataSize := make([]int, partitionNum)
for i := 0; i < rowNum; i++ {
p := fn(rows.GetRow(i)[id2])
row := getRowFromInsertData(rows, i)
p := fn(row[id2])
hashByPartRowsCount[p]++
hashByPartDataSize[p] += rows.GetRowSize(i)
hashByPartDataSize[p] += getRowSizeFromInsertData(rows, i)
}
// When autoID is enabled, the generated IDs will be evenly hashed across all channels.
// Therefore, here we just assign an average number of rows to each channel.
@@ -152,10 +177,10 @@ func GetRowsStats(task Task, rows *storage.InsertData) (map[string]*datapb.Parti
f1 := hashByVChannel(int64(channelNum), pkField)
f2 := hashByPartition(int64(partitionNum), partKeyField)
for i := 0; i < rowNum; i++ {
row := rows.GetRow(i)
row := getRowFromInsertData(rows, i)
p1, p2 := f1(row[id1]), f2(row[id2])
hashRowsCount[p1][p2]++
hashDataSize[p1][p2] += rows.GetRowSize(i)
hashDataSize[p1][p2] += getRowSizeFromInsertData(rows, i)
}
}
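
Why the helpers above exist: after this change a nullable/default_value column may arrive with zero rows while every other column holds N. Below is a minimal runnable sketch of that skip-short-columns pattern, using hypothetical simplified types in place of storage.InsertData and storage.FieldData:

package main

import "fmt"

// column is a hypothetical stand-in for storage.FieldData.
type column struct{ vals []any }

func (c column) RowNum() int      { return len(c.vals) }
func (c column) GetRow(i int) any { return c.vals[i] }

// getRow mirrors getRowFromInsertData above: columns with fewer than i+1 rows
// (e.g. fillable fields missing from the data file) are skipped, not indexed.
func getRow(data map[int64]column, i int) map[int64]any {
	res := make(map[int64]any)
	for fieldID, col := range data {
		if col.RowNum() > i {
			res[fieldID] = col.GetRow(i)
		}
	}
	return res
}

func main() {
	data := map[int64]column{
		100: {vals: []any{int64(1), int64(2)}}, // pk column: 2 rows
		102: {},                                // nullable column absent from the file: 0 rows
	}
	fmt.Println(getRow(data, 1)) // map[100:2]; field 102 is skipped, no panic
}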


@@ -191,7 +191,7 @@ func (t *ImportTask) importFile(reader importutilv2.Reader) error {
}
return err
}
rowNum := GetInsertDataRowCount(data, t.GetSchema())
rowNum, _ := GetInsertDataRowCount(data, t.GetSchema())
if rowNum == 0 {
log.Info("0 row was imported, the data may have been deleted", WrapLogFields(t)...)
continue
@@ -200,6 +200,10 @@ func (t *ImportTask) importFile(reader importutilv2.Reader) error {
if err != nil {
return err
}
err = AppendNullableDefaultFieldsData(t.GetSchema(), data, rowNum)
if err != nil {
return err
}
if !importutilv2.IsBackup(t.req.GetOptions()) {
err = RunEmbeddingFunction(t, data)
if err != nil {

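For orientation, the per-batch flow in importFile after this change, assembled from the hunks above (a sketch, not a compilable excerpt; the elided call before the new block is assumed to be AppendSystemFieldsData):

// data, err := reader.Read()                                         // decode one batch
// rowNum, _ := GetInsertDataRowCount(data, t.GetSchema())            // tolerates zero-row fillable columns
// err = AppendSystemFieldsData(t, data, rowNum)                      // assumed from surrounding context
// err = AppendNullableDefaultFieldsData(t.GetSchema(), data, rowNum) // new: fill absent columns
// err = RunEmbeddingFunction(t, data)                                // skipped for backup imports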

@@ -141,17 +141,9 @@ func CheckRowsEqual(schema *schemapb.CollectionSchema, data *storage.InsertData)
return field.GetFieldID()
})
var field int64
var rows int
rows, field := GetInsertDataRowCount(data, schema)
for fieldID, d := range data.Data {
if idToField[fieldID].GetIsPrimaryKey() && idToField[fieldID].GetAutoID() {
continue
}
field, rows = fieldID, d.RowNum()
break
}
for fieldID, d := range data.Data {
if idToField[fieldID].GetIsPrimaryKey() && idToField[fieldID].GetAutoID() {
if d.RowNum() == 0 && (CanBeZeroRowField(idToField[fieldID])) {
continue
}
if d.RowNum() != rows {
@@ -201,6 +193,156 @@ func AppendSystemFieldsData(task *ImportTask, data *storage.InsertData, rowNum i
return nil
}
type nullDefaultAppender[T any] struct{}
func (h *nullDefaultAppender[T]) AppendDefault(fieldData storage.FieldData, defaultVal T, rowNum int) error {
values := make([]T, rowNum)
if fieldData.GetNullable() {
validData := make([]bool, rowNum)
for i := 0; i < rowNum; i++ {
validData[i] = true // all true
values[i] = defaultVal // fill with default value
}
return fieldData.AppendRows(values, validData)
}
for i := 0; i < rowNum; i++ {
values[i] = defaultVal // fill with default value
}
return fieldData.AppendDataRows(values)
}
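// For example, AppendDefault(fd, int32(99), 3) appends [99 99 99]; if fd is
// nullable it also appends validData [true true true], marking every row valid.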
func (h *nullDefaultAppender[T]) AppendNull(fieldData storage.FieldData, rowNum int) error {
if fieldData.GetNullable() {
values := make([]T, rowNum)
validData := make([]bool, rowNum)
for i := 0; i < rowNum; i++ {
validData[i] = false
}
return fieldData.AppendRows(values, validData)
}
return nil
}
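// AppendNull appends rowNum zero values with validData all false. For a
// non-nullable field it is a no-op: only nullable fields can hold nulls.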
func IsFillableField(field *schemapb.FieldSchema) bool {
nullable := field.GetNullable()
defaultVal := field.GetDefaultValue()
return nullable || defaultVal != nil
}
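// Hypothetical examples: a field with Nullable set to true is fillable, as is
// one with a DefaultValue (e.g. an Int32 field defaulting to 99); a plain
// required field with neither attribute is not.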
func AppendNullableDefaultFieldsData(schema *schemapb.CollectionSchema, data *storage.InsertData, rowNum int) error {
for _, field := range schema.GetFields() {
if !IsFillableField(field) {
continue
}
if tempData, ok := data.Data[field.GetFieldID()]; ok {
if tempData.RowNum() > 0 {
continue // values have been read from data file
}
}
// add a new column and fill with null or default
dataType := field.GetDataType()
fieldData, err := storage.NewFieldData(dataType, field, rowNum)
if err != nil {
return err
}
data.Data[field.GetFieldID()] = fieldData
nullable := field.GetNullable()
defaultVal := field.GetDefaultValue()
// bool/int8/int16/int32/int64/float/double/varchar/json/array can hold null values
// bool/int8/int16/int32/int64/float/double/varchar can have a default value
switch dataType {
case schemapb.DataType_Bool:
appender := &nullDefaultAppender[bool]{}
if defaultVal != nil {
v := defaultVal.GetBoolData()
err = appender.AppendDefault(fieldData, v, rowNum)
} else if nullable {
err = appender.AppendNull(fieldData, rowNum)
}
case schemapb.DataType_Int8:
appender := &nullDefaultAppender[int8]{}
if defaultVal != nil {
v := defaultVal.GetIntData()
err = appender.AppendDefault(fieldData, int8(v), rowNum)
} else if nullable {
err = appender.AppendNull(fieldData, rowNum)
}
case schemapb.DataType_Int16:
appender := &nullDefaultAppender[int16]{}
if defaultVal != nil {
v := defaultVal.GetIntData()
err = appender.AppendDefault(fieldData, int16(v), rowNum)
} else if nullable {
err = appender.AppendNull(fieldData, rowNum)
}
case schemapb.DataType_Int32:
appender := &nullDefaultAppender[int32]{}
if defaultVal != nil {
v := defaultVal.GetIntData()
err = appender.AppendDefault(fieldData, int32(v), rowNum)
} else if nullable {
err = appender.AppendNull(fieldData, rowNum)
}
case schemapb.DataType_Int64:
appender := &nullDefaultAppender[int64]{}
if defaultVal != nil {
v := defaultVal.GetLongData()
err = appender.AppendDefault(fieldData, v, rowNum)
} else if nullable {
err = appender.AppendNull(fieldData, rowNum)
}
case schemapb.DataType_Float:
appender := &nullDefaultAppender[float32]{}
if defaultVal != nil {
v := defaultVal.GetFloatData()
err = appender.AppendDefault(fieldData, v, rowNum)
} else if nullable {
err = appender.AppendNull(fieldData, rowNum)
}
case schemapb.DataType_Double:
appender := &nullDefaultAppender[float64]{}
if defaultVal != nil {
v := defaultVal.GetDoubleData()
err = appender.AppendDefault(fieldData, v, rowNum)
} else if nullable {
err = appender.AppendNull(fieldData, rowNum)
}
case schemapb.DataType_VarChar:
appender := &nullDefaultAppender[string]{}
if defaultVal != nil {
v := defaultVal.GetStringData()
err = appender.AppendDefault(fieldData, v, rowNum)
} else if nullable {
err = appender.AppendNull(fieldData, rowNum)
}
case schemapb.DataType_JSON:
if nullable {
appender := &nullDefaultAppender[[]byte]{}
err = appender.AppendNull(fieldData, rowNum)
}
case schemapb.DataType_Array:
if nullable {
appender := &nullDefaultAppender[*schemapb.ScalarField]{}
err = appender.AppendNull(fieldData, rowNum)
}
default:
return fmt.Errorf("Unexpected data type: %d, cannot be filled with default value", dataType)
}
if err != nil {
return err
}
}
return nil
}
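// Usage sketch (hypothetical schema): if a parquet file omits a nullable "age"
// column, AppendNullableDefaultFieldsData(schema, data, rowNum) creates the
// missing FieldData and appends rowNum null (or default) entries, so every
// column in data.Data ends up with the same row count.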
func RunEmbeddingFunction(task *ImportTask, data *storage.InsertData) error {
if err := RunBm25Function(task, data); err != nil {
return err
@@ -275,19 +417,34 @@ func RunBm25Function(task *ImportTask, data *storage.InsertData) error {
return nil
}
func GetInsertDataRowCount(data *storage.InsertData, schema *schemapb.CollectionSchema) int {
func CanBeZeroRowField(field *schemapb.FieldSchema) bool {
if field.GetIsPrimaryKey() && field.GetAutoID() {
return true // auto-generated primary key, the row count must be 0
}
if field.GetIsDynamic() {
return true // dynamic field, row count could be 0
}
if IsFillableField(field) {
return true // nullable/default_value field can be automatically filled if the file doesn't contain this column
}
return false
}
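// Examples: an auto-ID primary key, the "$meta" dynamic field, and any
// nullable/default_value field may all legitimately arrive with 0 rows;
// GetInsertDataRowCount below skips exactly these fields.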
func GetInsertDataRowCount(data *storage.InsertData, schema *schemapb.CollectionSchema) (int, int64) {
fields := lo.KeyBy(schema.GetFields(), func(field *schemapb.FieldSchema) int64 {
return field.GetFieldID()
})
for fieldID, fd := range data.Data {
if fields[fieldID].GetIsDynamic() {
if fd.RowNum() == 0 && CanBeZeroRowField(fields[fieldID]) {
continue
}
// each collection must contain at least one vector field, so there must be at least one field whose row count is not 0
if fd.RowNum() != 0 {
return fd.RowNum()
return fd.RowNum(), fieldID
}
}
return 0
return 0, 0
}
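// The second return value is the fieldID whose row count was taken, presumably
// so callers such as CheckRowsEqual can report which column set the expected count.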
func LogStats(manager TaskManager) {


@@ -17,6 +17,7 @@
package importv2
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
@@ -24,6 +25,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util/testutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
@@ -71,7 +73,7 @@ func Test_AppendSystemFieldsData(t *testing.T) {
assert.Equal(t, 0, insertData.Data[pkField.GetFieldID()].RowNum())
assert.Nil(t, insertData.Data[common.RowIDField])
assert.Nil(t, insertData.Data[common.TimeStampField])
rowNum := GetInsertDataRowCount(insertData, task.GetSchema())
rowNum, _ := GetInsertDataRowCount(insertData, task.GetSchema())
err = AppendSystemFieldsData(task, insertData, rowNum)
assert.NoError(t, err)
assert.Equal(t, count, insertData.Data[pkField.GetFieldID()].RowNum())
@@ -85,7 +87,7 @@ func Test_AppendSystemFieldsData(t *testing.T) {
assert.Equal(t, 0, insertData.Data[pkField.GetFieldID()].RowNum())
assert.Nil(t, insertData.Data[common.RowIDField])
assert.Nil(t, insertData.Data[common.TimeStampField])
rowNum = GetInsertDataRowCount(insertData, task.GetSchema())
rowNum, _ = GetInsertDataRowCount(insertData, task.GetSchema())
err = AppendSystemFieldsData(task, insertData, rowNum)
assert.NoError(t, err)
assert.Equal(t, count, insertData.Data[pkField.GetFieldID()].RowNum())
@@ -175,3 +177,283 @@ func Test_PickSegment(t *testing.T) {
_, err := PickSegment(task.req.GetRequestSegments(), "ch-2", 20)
assert.Error(t, err)
}
func Test_AppendNullableDefaultFieldsData(t *testing.T) {
buildSchemaFn := func() *schemapb.CollectionSchema {
fields := make([]*schemapb.FieldSchema, 0)
fields = append(fields, &schemapb.FieldSchema{
FieldID: 100,
Name: "pk",
DataType: schemapb.DataType_Int64,
IsPrimaryKey: true,
AutoID: false,
})
fields = append(fields, &schemapb.FieldSchema{
FieldID: 101,
Name: "vec",
DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{
Key: common.DimKey,
Value: "4",
},
},
})
fields = append(fields, &schemapb.FieldSchema{
FieldID: 102,
Name: "dummy",
DataType: schemapb.DataType_Int32,
Nullable: true,
})
return &schemapb.CollectionSchema{
Fields: fields,
}
}
const count = 10
tests := []struct {
name string
fieldID int64
dataType schemapb.DataType
nullable bool
defaultVal *schemapb.ValueField
}{
// nullable tests
{
name: "bool is nullable",
fieldID: 200,
dataType: schemapb.DataType_Bool,
nullable: true,
},
{
name: "int8 is nullable",
fieldID: 200,
dataType: schemapb.DataType_Int8,
nullable: true,
},
{
name: "int16 is nullable",
fieldID: 200,
dataType: schemapb.DataType_Int16,
nullable: true,
},
{
name: "int32 is nullable",
fieldID: 200,
dataType: schemapb.DataType_Int32,
nullable: true,
},
{
name: "int64 is nullable",
fieldID: 200,
dataType: schemapb.DataType_Int64,
nullable: true,
defaultVal: nil,
},
{
name: "float is nullable",
fieldID: 200,
dataType: schemapb.DataType_Float,
nullable: true,
},
{
name: "double is nullable",
fieldID: 200,
dataType: schemapb.DataType_Double,
nullable: true,
},
{
name: "varchar is nullable",
fieldID: 200,
dataType: schemapb.DataType_VarChar,
nullable: true,
},
{
name: "json is nullable",
fieldID: 200,
dataType: schemapb.DataType_JSON,
nullable: true,
},
{
name: "array is nullable",
fieldID: 200,
dataType: schemapb.DataType_Array,
nullable: true,
},
// default value tests
{
name: "bool is default",
fieldID: 200,
dataType: schemapb.DataType_Bool,
defaultVal: &schemapb.ValueField{
Data: &schemapb.ValueField_BoolData{
BoolData: true,
},
},
},
{
name: "int8 is default",
fieldID: 200,
dataType: schemapb.DataType_Int8,
defaultVal: &schemapb.ValueField{
Data: &schemapb.ValueField_IntData{
IntData: 99,
},
},
},
{
name: "int16 is default",
fieldID: 200,
dataType: schemapb.DataType_Int16,
defaultVal: &schemapb.ValueField{
Data: &schemapb.ValueField_IntData{
IntData: 99,
},
},
},
{
name: "int32 is default",
fieldID: 200,
dataType: schemapb.DataType_Int32,
defaultVal: &schemapb.ValueField{
Data: &schemapb.ValueField_IntData{
IntData: 99,
},
},
},
{
name: "int64 is default",
fieldID: 200,
dataType: schemapb.DataType_Int64,
nullable: true,
defaultVal: &schemapb.ValueField{
Data: &schemapb.ValueField_LongData{
LongData: 99,
},
},
},
{
name: "float is default",
fieldID: 200,
dataType: schemapb.DataType_Float,
defaultVal: &schemapb.ValueField{
Data: &schemapb.ValueField_FloatData{
FloatData: 99.99,
},
},
},
{
name: "double is default",
fieldID: 200,
dataType: schemapb.DataType_Double,
defaultVal: &schemapb.ValueField{
Data: &schemapb.ValueField_DoubleData{
DoubleData: 99.99,
},
},
},
{
name: "varchar is default",
fieldID: 200,
dataType: schemapb.DataType_VarChar,
defaultVal: &schemapb.ValueField{
Data: &schemapb.ValueField_StringData{
StringData: "hello world",
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
schema := buildSchemaFn()
fieldSchema := &schemapb.FieldSchema{
FieldID: tt.fieldID,
Name: fmt.Sprintf("field_%d", tt.fieldID),
DataType: tt.dataType,
Nullable: tt.nullable,
DefaultValue: tt.defaultVal,
}
if tt.dataType == schemapb.DataType_Array {
fieldSchema.ElementType = schemapb.DataType_Int64
fieldSchema.TypeParams = append(fieldSchema.TypeParams, &commonpb.KeyValuePair{Key: common.MaxCapacityKey, Value: "100"})
} else if tt.dataType == schemapb.DataType_VarChar {
fieldSchema.TypeParams = append(fieldSchema.TypeParams, &commonpb.KeyValuePair{Key: common.MaxLengthKey, Value: "100"})
}
insertData, err := testutil.CreateInsertData(schema, count)
assert.NoError(t, err)
schema.Fields = append(schema.Fields, fieldSchema)
fieldData, err := storage.NewFieldData(fieldSchema.GetDataType(), fieldSchema, 0)
assert.NoError(t, err)
insertData.Data[fieldSchema.GetFieldID()] = fieldData
err = AppendNullableDefaultFieldsData(schema, insertData, count)
assert.NoError(t, err)
for fieldID, fieldData := range insertData.Data {
if fieldID < int64(200) {
continue
}
assert.Equal(t, count, fieldData.RowNum())
if tt.nullable {
assert.True(t, fieldData.GetNullable())
}
if tt.defaultVal != nil {
switch tt.dataType {
case schemapb.DataType_Bool:
tempFieldData := fieldData.(*storage.BoolFieldData)
for _, v := range tempFieldData.Data {
assert.True(t, v)
}
case schemapb.DataType_Int8:
tempFieldData := fieldData.(*storage.Int8FieldData)
for _, v := range tempFieldData.Data {
assert.Equal(t, int8(99), v)
}
case schemapb.DataType_Int16:
tempFieldData := fieldData.(*storage.Int16FieldData)
for _, v := range tempFieldData.Data {
assert.Equal(t, int16(99), v)
}
case schemapb.DataType_Int32:
tempFieldData := fieldData.(*storage.Int32FieldData)
for _, v := range tempFieldData.Data {
assert.Equal(t, int32(99), v)
}
case schemapb.DataType_Int64:
tempFieldData := fieldData.(*storage.Int64FieldData)
for _, v := range tempFieldData.Data {
assert.Equal(t, int64(99), v)
}
case schemapb.DataType_Float:
tempFieldData := fieldData.(*storage.FloatFieldData)
for _, v := range tempFieldData.Data {
assert.Equal(t, float32(99.99), v)
}
case schemapb.DataType_Double:
tempFieldData := fieldData.(*storage.DoubleFieldData)
for _, v := range tempFieldData.Data {
assert.Equal(t, float64(99.99), v)
}
case schemapb.DataType_VarChar:
tempFieldData := fieldData.(*storage.StringFieldData)
for _, v := range tempFieldData.Data {
assert.Equal(t, "hello world", v)
}
default:
}
} else if tt.nullable {
for i := 0; i < count; i++ {
assert.Nil(t, fieldData.GetRow(i))
}
}
}
})
}
}


@@ -67,6 +67,10 @@ func CreateFieldReaders(ctx context.Context, fileReader *pqarrow.FileReader, sch
return nil, merr.WrapErrImportFailed(
fmt.Sprintf("the primary key '%s' is auto-generated, no need to provide", field.GetName()))
}
if field.GetIsFunctionOutput() {
return nil, merr.WrapErrImportFailed(
fmt.Sprintf("the field '%s' is output by function, no need to provide", field.GetName()))
}
cr, err := NewFieldReader(ctx, fileReader, i, field)
if err != nil {
@@ -80,7 +84,8 @@
}
for _, field := range nameToField {
if typeutil.IsAutoPKField(field) || field.GetIsDynamic() || field.GetIsFunctionOutput() {
if typeutil.IsAutoPKField(field) || field.GetIsDynamic() || field.GetIsFunctionOutput() ||
field.GetNullable() || field.GetDefaultValue() != nil {
continue
}
if _, ok := crs[field.GetFieldID()]; !ok {
@@ -285,12 +290,17 @@ func isSchemaEqual(schema *schemapb.CollectionSchema, arrSchema *arrow.Schema) e
return field.Name
})
for _, field := range schema.GetFields() {
// ignore autoPKField and functionOutputField
if typeutil.IsAutoPKField(field) || field.GetIsFunctionOutput() {
continue
}
arrField, ok := arrNameToField[field.GetName()]
if !ok {
if field.GetIsDynamic() {
// Special fields need not be provided in data files; the parquet file doesn't contain them, so there is nothing to compare:
// 1. dynamic field (name is "$meta"), ignore
// 2. nullable field, filled with null values
// 3. default_value field, filled with the default value
if field.GetIsDynamic() || field.GetNullable() || field.GetDefaultValue() != nil {
continue
}
return merr.WrapErrImportFailed(fmt.Sprintf("field '%s' not in arrow schema", field.GetName()))
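
The relaxed checks in this file reduce to a single predicate; a condensed sketch (optionalInFile is a hypothetical helper name, not part of this commit):

package parquet

import "github.com/milvus-io/milvus-proto/go-api/v2/schemapb"

// optionalInFile reports whether a schema field may legitimately be absent
// from the data file: dynamic, nullable, and default_value fields are all
// reconstructed later (dynamic via "$meta", the rest by AppendNullableDefaultFieldsData).
func optionalInFile(field *schemapb.FieldSchema) bool {
	return field.GetIsDynamic() || field.GetNullable() || field.GetDefaultValue() != nil
}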