diff --git a/internal/storage/arrow_util.go b/internal/storage/arrow_util.go index 49f0e5270b..dc3d166941 100644 --- a/internal/storage/arrow_util.go +++ b/internal/storage/arrow_util.go @@ -22,8 +22,10 @@ import ( "github.com/apache/arrow/go/v17/arrow" "github.com/apache/arrow/go/v17/arrow/array" "github.com/apache/arrow/go/v17/arrow/memory" + "github.com/samber/lo" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) @@ -235,6 +237,69 @@ func appendValueAt(builder array.Builder, a arrow.Array, idx int, defaultValue * } } +// GenerateEmptyArrayFromSchema generate empty array from schema +// If schema has default value, the array will bef filled with it. +// Otherwise, null will be used instead. +// If input schema is not nullable, an error will be returned. +func GenerateEmptyArrayFromSchema(schema *schemapb.FieldSchema, numRows int) (arrow.Array, error) { + // if not nullable, return error + if !schema.GetNullable() { + return nil, merr.WrapErrServiceInternal(fmt.Sprintf("missing field data %s", schema.Name)) + } + dim, _ := typeutil.GetDim(schema) + builder := array.NewBuilder(memory.DefaultAllocator, serdeMap[schema.GetDataType()].arrowType(int(dim))) // serdeEntry[schema.GetDataType()].newBuilder() + if schema.GetDefaultValue() != nil { + switch schema.GetDataType() { + case schemapb.DataType_Bool: + bd := builder.(*array.BooleanBuilder) + bd.AppendValues( + lo.RepeatBy(numRows, func(_ int) bool { return schema.GetDefaultValue().GetBoolData() }), + nil) + case schemapb.DataType_Int8: + bd := builder.(*array.Int8Builder) + bd.AppendValues( + lo.RepeatBy(numRows, func(_ int) int8 { return int8(schema.GetDefaultValue().GetIntData()) }), + nil) + case schemapb.DataType_Int16: + bd := builder.(*array.Int16Builder) + bd.AppendValues( + lo.RepeatBy(numRows, func(_ int) int16 { return int16(schema.GetDefaultValue().GetIntData()) }), + nil) + case schemapb.DataType_Int32: + bd := builder.(*array.Int32Builder) + bd.AppendValues( + lo.RepeatBy(numRows, func(_ int) int32 { return schema.GetDefaultValue().GetIntData() }), + nil) + case schemapb.DataType_Int64: + bd := builder.(*array.Int64Builder) + bd.AppendValues( + lo.RepeatBy(numRows, func(_ int) int64 { return schema.GetDefaultValue().GetLongData() }), + nil) + case schemapb.DataType_Float: + bd := builder.(*array.Float32Builder) + bd.AppendValues( + lo.RepeatBy(numRows, func(_ int) float32 { return schema.GetDefaultValue().GetFloatData() }), + nil) + case schemapb.DataType_Double: + bd := builder.(*array.Float64Builder) + bd.AppendValues( + lo.RepeatBy(numRows, func(_ int) float64 { return schema.GetDefaultValue().GetDoubleData() }), + nil) + case schemapb.DataType_VarChar, schemapb.DataType_String: + bd := builder.(*array.StringBuilder) + bd.AppendValues( + lo.RepeatBy(numRows, func(_ int) string { return schema.GetDefaultValue().GetStringData() }), + nil) + default: + return nil, merr.WrapErrServiceInternal(fmt.Sprintf("Unexpected default value type: %s", schema.GetDataType().String())) + } + } else { + builder.AppendNulls(numRows) + } + + return builder.NewArray(), nil +} + // RecordBuilder is a helper to build arrow record. // Due to current arrow impl (v12), the write performance is largely dependent on the batch size, // small batch size will cause write performance degradation. To work around this issue, we accumulate diff --git a/internal/storage/arrow_util_test.go b/internal/storage/arrow_util_test.go new file mode 100644 index 0000000000..ecaefea071 --- /dev/null +++ b/internal/storage/arrow_util_test.go @@ -0,0 +1,220 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "math/rand" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func TestGenerateEmptyArray(t *testing.T) { + type testCase struct { + tag string + field *schemapb.FieldSchema + expectErr bool + expectNull bool + expectValue any + } + + cases := []testCase{ + { + tag: "no_default_value", + field: &schemapb.FieldSchema{ + DataType: schemapb.DataType_Int8, + Nullable: true, + }, + expectErr: false, + expectNull: true, + }, + { + tag: "int8", + field: &schemapb.FieldSchema{ + DataType: schemapb.DataType_Int8, + Nullable: true, + DefaultValue: &schemapb.ValueField{ + Data: &schemapb.ValueField_IntData{ + IntData: 10, + }, + }, + }, + expectErr: false, + expectNull: false, + expectValue: int8(10), + }, + { + tag: "int16", + field: &schemapb.FieldSchema{ + DataType: schemapb.DataType_Int16, + Nullable: true, + DefaultValue: &schemapb.ValueField{ + Data: &schemapb.ValueField_IntData{ + IntData: 16, + }, + }, + }, + expectErr: false, + expectNull: false, + expectValue: int16(16), + }, + { + tag: "int32", + field: &schemapb.FieldSchema{ + DataType: schemapb.DataType_Int32, + Nullable: true, + DefaultValue: &schemapb.ValueField{ + Data: &schemapb.ValueField_IntData{ + IntData: 32, + }, + }, + }, + expectErr: false, + expectNull: false, + expectValue: int32(32), + }, + { + tag: "int64", + field: &schemapb.FieldSchema{ + DataType: schemapb.DataType_Int64, + Nullable: true, + DefaultValue: &schemapb.ValueField{ + Data: &schemapb.ValueField_LongData{ + LongData: 64, + }, + }, + }, + expectErr: false, + expectNull: false, + expectValue: int64(64), + }, + { + tag: "bool", + field: &schemapb.FieldSchema{ + DataType: schemapb.DataType_Bool, + Nullable: true, + DefaultValue: &schemapb.ValueField{ + Data: &schemapb.ValueField_BoolData{ + BoolData: true, + }, + }, + }, + expectErr: false, + expectNull: false, + expectValue: true, + }, + { + tag: "float", + field: &schemapb.FieldSchema{ + DataType: schemapb.DataType_Float, + Nullable: true, + DefaultValue: &schemapb.ValueField{ + Data: &schemapb.ValueField_FloatData{ + FloatData: 0.1, + }, + }, + }, + expectErr: false, + expectNull: false, + expectValue: float32(0.1), + }, + { + tag: "double", + field: &schemapb.FieldSchema{ + DataType: schemapb.DataType_Double, + Nullable: true, + DefaultValue: &schemapb.ValueField{ + Data: &schemapb.ValueField_DoubleData{ + DoubleData: 1.2, + }, + }, + }, + expectErr: false, + expectNull: false, + expectValue: float64(1.2), + }, + { + tag: "varchar", + field: &schemapb.FieldSchema{ + DataType: schemapb.DataType_VarChar, + Nullable: true, + DefaultValue: &schemapb.ValueField{ + Data: &schemapb.ValueField_StringData{ + StringData: "varchar", + }, + }, + }, + expectErr: false, + expectNull: false, + expectValue: "varchar", + }, + { + tag: "invalid_schema_datatype", + field: &schemapb.FieldSchema{ + DataType: schemapb.DataType_FloatVector, + Nullable: true, + DefaultValue: &schemapb.ValueField{ + Data: &schemapb.ValueField_IntData{ + IntData: 10, + }, + }, + }, + expectErr: true, + }, + { + tag: "invalid_schema_nullable", + field: &schemapb.FieldSchema{ + DataType: schemapb.DataType_Int8, + Nullable: false, + DefaultValue: &schemapb.ValueField{ + Data: &schemapb.ValueField_IntData{ + IntData: 10, + }, + }, + }, + expectErr: true, + }, + } + + for _, tc := range cases { + t.Run(tc.tag, func(t *testing.T) { + rowNum := rand.Intn(100) + 1 + a, err := GenerateEmptyArrayFromSchema(tc.field, rowNum) + switch { + case tc.expectErr: + assert.Error(t, err) + case tc.expectNull: + assert.NoError(t, err) + assert.EqualValues(t, rowNum, a.Len()) + for i := range rowNum { + assert.True(t, a.IsNull(i)) + } + default: + assert.NoError(t, err) + assert.EqualValues(t, rowNum, a.Len()) + for i := range rowNum { + value, ok := serdeMap[tc.field.DataType].deserialize(a, i) + assert.True(t, a.IsValid(i)) + assert.True(t, ok) + assert.Equal(t, tc.expectValue, value) + } + } + }) + } +} diff --git a/internal/storage/serde_events.go b/internal/storage/serde_events.go index aea99065b6..ed5b5c1f3d 100644 --- a/internal/storage/serde_events.go +++ b/internal/storage/serde_events.go @@ -122,13 +122,12 @@ func (crr *CompositeBinlogRecordReader) Next() (Record, error) { // If the field is not in the current batch, fill with null array // Note that we're intentionally not filling default value here, because the // deserializer will fill them later. - if !f.Nullable { - return nil, merr.WrapErrServiceInternal(fmt.Sprintf("missing field data %s", f.Name)) + numRows := int(crr.rrs[0].Record().NumRows()) + arr, err := GenerateEmptyArrayFromSchema(f, numRows) + if err != nil { + return nil, err } - dim, _ := typeutil.GetDim(f) - builder := array.NewBuilder(memory.DefaultAllocator, serdeMap[f.DataType].arrowType(int(dim))) - builder.AppendNulls(int(crr.rrs[0].Record().NumRows())) - recs[i] = builder.NewArray() + recs[i] = arr } } return &compositeRecord{