diff --git a/internal/rootcoord/import_manager.go b/internal/rootcoord/import_manager.go
index f7f4085a25..0e2c64f1da 100644
--- a/internal/rootcoord/import_manager.go
+++ b/internal/rootcoord/import_manager.go
@@ -399,18 +399,18 @@ func (m *importManager) isRowbased(files []string) (bool, error) {
 	isRowBased := false
 	for _, filePath := range files {
 		_, fileType := importutil.GetFileNameAndExt(filePath)
-		if fileType == importutil.JSONFileExt {
+		if fileType == importutil.JSONFileExt || fileType == importutil.CSVFileExt {
 			isRowBased = true
 		} else if isRowBased {
-			log.Error("row-based data file type must be JSON, mixed file types is not allowed", zap.Strings("files", files))
-			return isRowBased, fmt.Errorf("row-based data file type must be JSON, file type '%s' is not allowed", fileType)
+			log.Error("row-based data file type must be JSON or CSV, mixed file types is not allowed", zap.Strings("files", files))
+			return isRowBased, fmt.Errorf("row-based data file type must be JSON or CSV, file type '%s' is not allowed", fileType)
 		}
 	}
 
 	// for row_based, we only allow one file so that each invocation only generate a task
 	if isRowBased && len(files) > 1 {
-		log.Error("row-based import, only allow one JSON file each time", zap.Strings("files", files))
-		return isRowBased, fmt.Errorf("row-based import, only allow one JSON file each time")
+		log.Error("row-based import, only allow one JSON or CSV file each time", zap.Strings("files", files))
+		return isRowBased, fmt.Errorf("row-based import, only allow one JSON or CSV file each time")
 	}
 
 	return isRowBased, nil
diff --git a/internal/rootcoord/import_manager_test.go b/internal/rootcoord/import_manager_test.go
index fe94b2db66..04a6ea0d2e 100644
--- a/internal/rootcoord/import_manager_test.go
+++ b/internal/rootcoord/import_manager_test.go
@@ -1101,6 +1101,26 @@ func TestImportManager_isRowbased(t *testing.T) {
 	rb, err = mgr.isRowbased(files)
 	assert.NoError(t, err)
 	assert.False(t, rb)
+
+	files = []string{"1.csv"}
+	rb, err = mgr.isRowbased(files)
+	assert.NoError(t, err)
+	assert.True(t, rb)
+
+	files = []string{"1.csv", "2.csv"}
+	rb, err = mgr.isRowbased(files)
+	assert.Error(t, err)
+	assert.True(t, rb)
+
+	files = []string{"1.csv", "2.json"}
+	rb, err = mgr.isRowbased(files)
+	assert.Error(t, err)
+	assert.True(t, rb)
+
+	files = []string{"1.csv", "2.npy"}
+	rb, err = mgr.isRowbased(files)
+	assert.Error(t, err)
+	assert.True(t, rb)
 }
 
 func TestImportManager_mergeArray(t *testing.T) {
diff --git a/internal/util/importutil/csv_handler.go b/internal/util/importutil/csv_handler.go
new file mode 100644
index 0000000000..f06596f2cb
--- /dev/null
+++ b/internal/util/importutil/csv_handler.go
@@ -0,0 +1,446 @@
+// Licensed to the LF AI & Data foundation under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package importutil + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + + "github.com/cockroachdb/errors" + "go.uber.org/zap" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/common" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type CSVRowHandler interface { + Handle(row []map[storage.FieldID]string) error +} + +// CSVRowConsumer is row-based csv format consumer class +type CSVRowConsumer struct { + ctx context.Context // for canceling parse process + collectionInfo *CollectionInfo // collection details including schema + rowIDAllocator *allocator.IDAllocator // autoid allocator + validators map[storage.FieldID]*CSVValidator // validators for each field + rowCounter int64 // how many rows have been consumed + shardsData []ShardData // in-memory shards data + blockSize int64 // maximum size of a read block(unit:byte) + autoIDRange []int64 // auto-generated id range, for example: [1, 10, 20, 25] means id from 1 to 10 and 20 to 25 + + callFlushFunc ImportFlushFunc // call back function to flush segment +} + +func NewCSVRowConsumer(ctx context.Context, + collectionInfo *CollectionInfo, + idAlloc *allocator.IDAllocator, + blockSize int64, + flushFunc ImportFlushFunc, +) (*CSVRowConsumer, error) { + if collectionInfo == nil { + log.Warn("CSV row consumer: collection schema is nil") + return nil, errors.New("collection schema is nil") + } + + v := &CSVRowConsumer{ + ctx: ctx, + collectionInfo: collectionInfo, + rowIDAllocator: idAlloc, + validators: make(map[storage.FieldID]*CSVValidator, 0), + rowCounter: 0, + shardsData: make([]ShardData, 0, collectionInfo.ShardNum), + blockSize: blockSize, + autoIDRange: make([]int64, 0), + callFlushFunc: flushFunc, + } + + if err := v.initValidators(collectionInfo.Schema); err != nil { + log.Warn("CSV row consumer: fail to initialize csv row-based consumer", zap.Error(err)) + return nil, fmt.Errorf("fail to initialize csv row-based consumer, error: %w", err) + } + + for i := 0; i < int(collectionInfo.ShardNum); i++ { + shardData := initShardData(collectionInfo.Schema, collectionInfo.PartitionIDs) + if shardData == nil { + log.Warn("CSV row consumer: fail to initialize in-memory segment data", zap.Int("shardID", i)) + return nil, fmt.Errorf("fail to initialize in-memory segment data for shard id %d", i) + } + v.shardsData = append(v.shardsData, shardData) + } + + // primary key is autoid, id generator is required + if v.collectionInfo.PrimaryKey.GetAutoID() && idAlloc == nil { + log.Warn("CSV row consumer: ID allocator is nil") + return nil, errors.New("ID allocator is nil") + } + + return v, nil +} + +type CSVValidator struct { + convertFunc func(val string, field storage.FieldData) error // convert data function + isString bool // for string field + fieldName string // field name +} + +func (v *CSVRowConsumer) initValidators(collectionSchema *schemapb.CollectionSchema) error { + if collectionSchema == nil { + return errors.New("collection schema is nil") + } + + validators := v.validators + + for i := 0; i < len(collectionSchema.Fields); i++ { + schema := collectionSchema.Fields[i] + + validators[schema.GetFieldID()] = &CSVValidator{} + validators[schema.GetFieldID()].fieldName = schema.GetName() + validators[schema.GetFieldID()].isString = false + + switch schema.DataType { + // all obj is string type + case schemapb.DataType_Bool: + 
validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { + var value bool + if err := json.Unmarshal([]byte(str), &value); err != nil { + return fmt.Errorf("illegal value '%v' for bool type field '%s'", str, schema.GetName()) + } + field.(*storage.BoolFieldData).Data = append(field.(*storage.BoolFieldData).Data, value) + return nil + } + case schemapb.DataType_Float: + validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { + value, err := parseFloat(str, 32, schema.GetName()) + if err != nil { + return err + } + field.(*storage.FloatFieldData).Data = append(field.(*storage.FloatFieldData).Data, float32(value)) + return nil + } + case schemapb.DataType_Double: + validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { + value, err := parseFloat(str, 64, schema.GetName()) + if err != nil { + return err + } + field.(*storage.DoubleFieldData).Data = append(field.(*storage.DoubleFieldData).Data, value) + return nil + } + case schemapb.DataType_Int8: + validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { + value, err := strconv.ParseInt(str, 0, 8) + if err != nil { + return fmt.Errorf("failed to parse value '%v' for int8 field '%s', error: %w", str, schema.GetName(), err) + } + field.(*storage.Int8FieldData).Data = append(field.(*storage.Int8FieldData).Data, int8(value)) + return nil + } + case schemapb.DataType_Int16: + validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { + value, err := strconv.ParseInt(str, 0, 16) + if err != nil { + return fmt.Errorf("failed to parse value '%v' for int16 field '%s', error: %w", str, schema.GetName(), err) + } + field.(*storage.Int16FieldData).Data = append(field.(*storage.Int16FieldData).Data, int16(value)) + return nil + } + case schemapb.DataType_Int32: + validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { + value, err := strconv.ParseInt(str, 0, 32) + if err != nil { + return fmt.Errorf("failed to parse value '%v' for int32 field '%s', error: %w", str, schema.GetName(), err) + } + field.(*storage.Int32FieldData).Data = append(field.(*storage.Int32FieldData).Data, int32(value)) + return nil + } + case schemapb.DataType_Int64: + validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { + value, err := strconv.ParseInt(str, 0, 64) + if err != nil { + return fmt.Errorf("failed to parse value '%v' for int64 field '%s', error: %w", str, schema.GetName(), err) + } + field.(*storage.Int64FieldData).Data = append(field.(*storage.Int64FieldData).Data, value) + return nil + } + case schemapb.DataType_BinaryVector: + dim, err := getFieldDimension(schema) + if err != nil { + return err + } + + validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { + var arr []interface{} + desc := json.NewDecoder(strings.NewReader(str)) + desc.UseNumber() + if err := desc.Decode(&arr); err != nil { + return fmt.Errorf("'%v' is not an array for binary vector field '%s'", str, schema.GetName()) + } + + // we use uint8 to represent binary vector in csv file, each uint8 value represents 8 dimensions. 
+			if len(arr)*8 != dim {
+				return fmt.Errorf("bit size %d doesn't equal to vector dimension %d of field '%s'", len(arr)*8, dim, schema.GetName())
+			}
+
+			for i := 0; i < len(arr); i++ {
+				if num, ok := arr[i].(json.Number); ok {
+					value, err := strconv.ParseUint(string(num), 0, 8)
+					if err != nil {
+						return fmt.Errorf("failed to parse value '%v' for binary vector field '%s', error: %w", num, schema.GetName(), err)
+					}
+					field.(*storage.BinaryVectorFieldData).Data = append(field.(*storage.BinaryVectorFieldData).Data, byte(value))
+				} else {
+					return fmt.Errorf("illegal value '%v' for binary vector field '%s'", str, schema.GetName())
+				}
+			}
+
+				return nil
+			}
+		case schemapb.DataType_FloatVector:
+			dim, err := getFieldDimension(schema)
+			if err != nil {
+				return err
+			}
+
+			validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
+				var arr []interface{}
+				desc := json.NewDecoder(strings.NewReader(str))
+				desc.UseNumber()
+				if err := desc.Decode(&arr); err != nil {
+					return fmt.Errorf("'%v' is not an array for float vector field '%s'", str, schema.GetName())
+				}
+
+				if len(arr) != dim {
+					return fmt.Errorf("array size %d doesn't equal to vector dimension %d of field '%s'", len(arr), dim, schema.GetName())
+				}
+
+				for i := 0; i < len(arr); i++ {
+					if num, ok := arr[i].(json.Number); ok {
+						value, err := parseFloat(string(num), 32, schema.GetName())
+						if err != nil {
+							return err
+						}
+						field.(*storage.FloatVectorFieldData).Data = append(field.(*storage.FloatVectorFieldData).Data, float32(value))
+					} else {
+						return fmt.Errorf("illegal value '%v' for float vector field '%s'", str, schema.GetName())
+					}
+				}
+
+				return nil
+			}
+		case schemapb.DataType_String, schemapb.DataType_VarChar:
+			validators[schema.GetFieldID()].isString = true
+
+			validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
+				field.(*storage.StringFieldData).Data = append(field.(*storage.StringFieldData).Data, str)
+				return nil
+			}
+		case schemapb.DataType_JSON:
+			validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error {
+				var dummy interface{}
+				if err := json.Unmarshal([]byte(str), &dummy); err != nil {
+					return fmt.Errorf("failed to parse value '%v' for JSON field '%s', error: %w", str, schema.GetName(), err)
+				}
+				field.(*storage.JSONFieldData).Data = append(field.(*storage.JSONFieldData).Data, []byte(str))
+				return nil
+			}
+		default:
+			return fmt.Errorf("unsupported data type: %s", getTypeName(collectionSchema.Fields[i].DataType))
+		}
+	}
+	return nil
+}
+
+func (v *CSVRowConsumer) IDRange() []int64 {
+	return v.autoIDRange
+}
+
+func (v *CSVRowConsumer) RowCount() int64 {
+	return v.rowCounter
+}
+
+func (v *CSVRowConsumer) Handle(rows []map[storage.FieldID]string) error {
+	if v == nil || v.validators == nil || len(v.validators) == 0 {
+		log.Warn("CSV row consumer is not initialized")
+		return errors.New("CSV row consumer is not initialized")
+	}
+	// if rows is nil, that means read to end of file, force flush all data
+	if rows == nil {
+		err := tryFlushBlocks(v.ctx, v.shardsData, v.collectionInfo.Schema, v.callFlushFunc, v.blockSize, MaxTotalSizeInMemory, true)
+		log.Info("CSV row consumer finished")
+		return err
+	}
+
+	// rows is not nil, flush if necessary:
+	// 1. data block size larger than v.blockSize will be flushed
+	// 2. total data size exceeds MaxTotalSizeInMemory, the largest data block will be flushed
+	err := tryFlushBlocks(v.ctx, v.shardsData, v.collectionInfo.Schema, v.callFlushFunc, v.blockSize, MaxTotalSizeInMemory, false)
+	if err != nil {
+		log.Warn("CSV row consumer: try flush data but failed", zap.Error(err))
+		return fmt.Errorf("try flush data but failed, error: %w", err)
+	}
+
+	// prepare autoid, no matter int64 or varchar pk, we always generate autoid since the hidden field RowIDField requires them
+	primaryKeyID := v.collectionInfo.PrimaryKey.FieldID
+	primaryValidator := v.validators[primaryKeyID]
+	var rowIDBegin typeutil.UniqueID
+	var rowIDEnd typeutil.UniqueID
+	if v.collectionInfo.PrimaryKey.AutoID {
+		if v.rowIDAllocator == nil {
+			log.Warn("CSV row consumer: primary key is auto-generated but IDAllocator is nil")
+			return fmt.Errorf("primary key is auto-generated but IDAllocator is nil")
+		}
+		var err error
+		rowIDBegin, rowIDEnd, err = v.rowIDAllocator.Alloc(uint32(len(rows)))
+		if err != nil {
+			log.Warn("CSV row consumer: failed to generate primary keys", zap.Int("count", len(rows)), zap.Error(err))
+			return fmt.Errorf("failed to generate %d primary keys, error: %w", len(rows), err)
+		}
+		if rowIDEnd-rowIDBegin != int64(len(rows)) {
+			log.Warn("CSV row consumer: try to generate primary keys but allocated ids are not enough",
+				zap.Int("count", len(rows)), zap.Int64("generated", rowIDEnd-rowIDBegin))
+			return fmt.Errorf("try to generate %d primary keys but only %d keys were allocated", len(rows), rowIDEnd-rowIDBegin)
+		}
+		log.Info("CSV row consumer: auto-generate primary keys", zap.Int64("begin", rowIDBegin), zap.Int64("end", rowIDEnd))
+		if primaryValidator.isString {
+			// if pk is varchar, no need to record auto-generated row ids
+			log.Warn("CSV row consumer: string type primary key cannot be auto-generated")
+			return errors.New("string type primary key cannot be auto-generated")
+		}
+		v.autoIDRange = append(v.autoIDRange, rowIDBegin, rowIDEnd)
+	}
+
+	// consume rows
+	for i := 0; i < len(rows); i++ {
+		row := rows[i]
+		rowNumber := v.rowCounter + int64(i)
+
+		// hash to a shard number
+		var shardID uint32
+		var partitionID int64
+		if primaryValidator.isString {
+			pk := row[primaryKeyID]
+
+			// hash to shard based on pk, hash to partition if partition key exist
+			hash := typeutil.HashString2Uint32(pk)
+			shardID = hash % uint32(v.collectionInfo.ShardNum)
+			partitionID, err = v.hashToPartition(row, rowNumber)
+			if err != nil {
+				return err
+			}
+
+			pkArray := v.shardsData[shardID][partitionID][primaryKeyID].(*storage.StringFieldData)
+			pkArray.Data = append(pkArray.Data, pk)
+		} else {
+			var pk int64
+			if v.collectionInfo.PrimaryKey.AutoID {
+				pk = rowIDBegin + int64(i)
+			} else {
+				pkStr := row[primaryKeyID]
+				pk, err = strconv.ParseInt(pkStr, 10, 64)
+				if err != nil {
+					log.Warn("CSV row consumer: failed to parse primary key at the row",
+						zap.String("value", pkStr), zap.Int64("rowNumber", rowNumber), zap.Error(err))
+					return fmt.Errorf("failed to parse primary key '%s' at the row %d, error: %w",
+						pkStr, rowNumber, err)
+				}
+			}
+
+			hash, err := typeutil.Hash32Int64(pk)
+			if err != nil {
+				log.Warn("CSV row consumer: failed to hash primary key at the row",
+					zap.Int64("key", pk), zap.Int64("rowNumber", rowNumber), zap.Error(err))
+				return fmt.Errorf("failed to hash primary key %d at the row %d, error: %w", pk, rowNumber, err)
+			}
+
+			// hash to shard based on pk, hash to partition if partition key exist
+			shardID = hash % uint32(v.collectionInfo.ShardNum)
+			partitionID, err =
v.hashToPartition(row, rowNumber) + if err != nil { + return err + } + + pkArray := v.shardsData[shardID][partitionID][primaryKeyID].(*storage.Int64FieldData) + pkArray.Data = append(pkArray.Data, pk) + } + rowIDField := v.shardsData[shardID][partitionID][common.RowIDField].(*storage.Int64FieldData) + rowIDField.Data = append(rowIDField.Data, rowIDBegin+int64(i)) + + for fieldID, validator := range v.validators { + if fieldID == v.collectionInfo.PrimaryKey.GetFieldID() { + continue + } + + value := row[fieldID] + if err := validator.convertFunc(value, v.shardsData[shardID][partitionID][fieldID]); err != nil { + log.Warn("CSV row consumer: failed to convert value for field at the row", + zap.String("fieldName", validator.fieldName), zap.Int64("rowNumber", rowNumber), zap.Error(err)) + return fmt.Errorf("failed to convert value for field '%s' at the row %d, error: %w", + validator.fieldName, rowNumber, err) + } + } + } + + v.rowCounter += int64(len(rows)) + return nil +} + +// hashToPartition hash partition key to get an partition ID, return the first partition ID if no partition key exist +// CollectionInfo ensures only one partition ID in the PartitionIDs if no partition key exist +func (v *CSVRowConsumer) hashToPartition(row map[storage.FieldID]string, rowNumber int64) (int64, error) { + if v.collectionInfo.PartitionKey == nil { + if len(v.collectionInfo.PartitionIDs) != 1 { + return 0, fmt.Errorf("collection '%s' partition list is empty", v.collectionInfo.Schema.Name) + } + // no partition key, directly return the target partition id + return v.collectionInfo.PartitionIDs[0], nil + } + + partitionKeyID := v.collectionInfo.PartitionKey.GetFieldID() + partitionKeyValidator := v.validators[partitionKeyID] + value := row[partitionKeyID] + + var hashValue uint32 + if partitionKeyValidator.isString { + hashValue = typeutil.HashString2Uint32(value) + } else { + // parse the value from a string + pk, err := strconv.ParseInt(value, 10, 64) + if err != nil { + log.Warn("CSV row consumer: failed to parse partition key at the row", + zap.String("value", value), zap.Int64("rowNumber", rowNumber), zap.Error(err)) + return 0, fmt.Errorf("failed to parse partition key '%s' at the row %d, error: %w", + value, rowNumber, err) + } + + hashValue, err = typeutil.Hash32Int64(pk) + if err != nil { + log.Warn("CSV row consumer: failed to hash partition key at the row", + zap.Int64("key", pk), zap.Int64("rowNumber", rowNumber), zap.Error(err)) + return 0, fmt.Errorf("failed to hash partition key %d at the row %d, error: %w", pk, rowNumber, err) + } + } + + index := int64(hashValue % uint32(len(v.collectionInfo.PartitionIDs))) + return v.collectionInfo.PartitionIDs[index], nil +} diff --git a/internal/util/importutil/csv_handler_test.go b/internal/util/importutil/csv_handler_test.go new file mode 100644 index 0000000000..a2a388b6a2 --- /dev/null +++ b/internal/util/importutil/csv_handler_test.go @@ -0,0 +1,760 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importutil + +import ( + "context" + "strconv" + "testing" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/common" +) + +func Test_CSVRowConsumerNew(t *testing.T) { + ctx := context.Background() + + t.Run("nil schema", func(t *testing.T) { + consumer, err := NewCSVRowConsumer(ctx, nil, nil, 16, nil) + assert.Error(t, err) + assert.Nil(t, consumer) + }) + + t.Run("wrong schema", func(t *testing.T) { + schema := &schemapb.CollectionSchema{ + Name: "schema", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 101, + Name: "uid", + IsPrimaryKey: true, + AutoID: false, + DataType: schemapb.DataType_Int64, + }, + }, + } + collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1}) + assert.NoError(t, err) + + schema.Fields[0].DataType = schemapb.DataType_None + consumer, err := NewCSVRowConsumer(ctx, collectionInfo, nil, 16, nil) + assert.Error(t, err) + assert.Nil(t, consumer) + }) + + t.Run("primary key is autoid but no IDAllocator", func(t *testing.T) { + schema := &schemapb.CollectionSchema{ + Name: "schema", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 101, + Name: "uid", + IsPrimaryKey: true, + AutoID: true, + DataType: schemapb.DataType_Int64, + }, + }, + } + collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1}) + assert.NoError(t, err) + + consumer, err := NewCSVRowConsumer(ctx, collectionInfo, nil, 16, nil) + assert.Error(t, err) + assert.Nil(t, consumer) + }) + + t.Run("succeed", func(t *testing.T) { + collectionInfo, err := NewCollectionInfo(sampleSchema(), 2, []int64{1}) + assert.NoError(t, err) + + consumer, err := NewCSVRowConsumer(ctx, collectionInfo, nil, 16, nil) + assert.NoError(t, err) + assert.NotNil(t, consumer) + }) +} + +func Test_CSVRowConsumerInitValidators(t *testing.T) { + ctx := context.Background() + consumer := &CSVRowConsumer{ + ctx: ctx, + validators: make(map[int64]*CSVValidator), + } + + collectionInfo, err := NewCollectionInfo(sampleSchema(), 2, []int64{1}) + assert.NoError(t, err) + schema := collectionInfo.Schema + err = consumer.initValidators(schema) + assert.NoError(t, err) + assert.Equal(t, len(schema.Fields), len(consumer.validators)) + for _, field := range schema.Fields { + fieldID := field.GetFieldID() + assert.Equal(t, field.GetName(), consumer.validators[fieldID].fieldName) + if field.GetDataType() != schemapb.DataType_VarChar && field.GetDataType() != schemapb.DataType_String { + assert.False(t, consumer.validators[fieldID].isString) + } else { + assert.True(t, consumer.validators[fieldID].isString) + } + } + + name2ID := make(map[string]storage.FieldID) + for _, field := range schema.Fields { + name2ID[field.GetName()] = field.GetFieldID() + } + + fields := initBlockData(schema) + assert.NotNil(t, fields) + + checkConvertFunc := func(funcName string, validVal string, invalidVal string) { + id := name2ID[funcName] + v, ok := consumer.validators[id] + 
assert.True(t, ok) + + fieldData := fields[id] + preNum := fieldData.RowNum() + err = v.convertFunc(validVal, fieldData) + assert.NoError(t, err) + postNum := fieldData.RowNum() + assert.Equal(t, 1, postNum-preNum) + + err = v.convertFunc(invalidVal, fieldData) + assert.Error(t, err) + } + + t.Run("check convert functions", func(t *testing.T) { + // all val is string type + validVal := "true" + invalidVal := "5" + checkConvertFunc("FieldBool", validVal, invalidVal) + + validVal = "100" + invalidVal = "128" + checkConvertFunc("FieldInt8", validVal, invalidVal) + + invalidVal = "65536" + checkConvertFunc("FieldInt16", validVal, invalidVal) + + invalidVal = "2147483648" + checkConvertFunc("FieldInt32", validVal, invalidVal) + + invalidVal = "1.2" + checkConvertFunc("FieldInt64", validVal, invalidVal) + + invalidVal = "dummy" + checkConvertFunc("FieldFloat", validVal, invalidVal) + checkConvertFunc("FieldDouble", validVal, invalidVal) + + // json type + validVal = `{"x": 5, "y": true, "z": "hello"}` + checkConvertFunc("FieldJSON", validVal, "a") + checkConvertFunc("FieldJSON", validVal, "{") + + // the binary vector dimension is 16, shoud input two uint8 values, each value should between 0~255 + validVal = "[100, 101]" + invalidVal = "[100, 1256]" + checkConvertFunc("FieldBinaryVector", validVal, invalidVal) + + invalidVal = "false" + checkConvertFunc("FieldBinaryVector", validVal, invalidVal) + invalidVal = "[100]" + checkConvertFunc("FieldBinaryVector", validVal, invalidVal) + invalidVal = "[100.2, 102.5]" + checkConvertFunc("FieldBinaryVector", validVal, invalidVal) + + // the float vector dimension is 4, each value should be valid float number + validVal = "[1,2,3,4]" + invalidVal = `[1,2,3,"dummy"]` + checkConvertFunc("FieldFloatVector", validVal, invalidVal) + invalidVal = "true" + checkConvertFunc("FieldFloatVector", validVal, invalidVal) + invalidVal = `[1]` + checkConvertFunc("FieldFloatVector", validVal, invalidVal) + }) + + t.Run("init error cases", func(t *testing.T) { + // schema is nil + err := consumer.initValidators(nil) + assert.Error(t, err) + + schema = &schemapb.CollectionSchema{ + Name: "schema", + Description: "schema", + AutoID: true, + Fields: make([]*schemapb.FieldSchema, 0), + } + schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ + FieldID: 111, + Name: "FieldFloatVector", + IsPrimaryKey: false, + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "aa"}, + }, + }) + consumer.validators = make(map[int64]*CSVValidator) + err = consumer.initValidators(schema) + assert.Error(t, err) + + schema.Fields = make([]*schemapb.FieldSchema, 0) + schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ + FieldID: 110, + Name: "FieldBinaryVector", + IsPrimaryKey: false, + DataType: schemapb.DataType_BinaryVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "aa"}, + }, + }) + + err = consumer.initValidators(schema) + assert.Error(t, err) + + // unsupported data type + schema.Fields = make([]*schemapb.FieldSchema, 0) + schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ + FieldID: 110, + Name: "dummy", + IsPrimaryKey: false, + DataType: schemapb.DataType_None, + }) + + err = consumer.initValidators(schema) + assert.Error(t, err) + }) + + t.Run("json field", func(t *testing.T) { + schema = &schemapb.CollectionSchema{ + Name: "schema", + Description: "schema", + AutoID: true, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 102, + Name: "FieldJSON", + DataType: 
schemapb.DataType_JSON, + }, + }, + } + consumer.validators = make(map[int64]*CSVValidator) + err = consumer.initValidators(schema) + assert.NoError(t, err) + + v, ok := consumer.validators[102] + assert.True(t, ok) + + fields := initBlockData(schema) + assert.NotNil(t, fields) + fieldData := fields[102] + + err = v.convertFunc("{\"x\": 1, \"y\": 5}", fieldData) + assert.NoError(t, err) + assert.Equal(t, 1, fieldData.RowNum()) + + err = v.convertFunc("{}", fieldData) + assert.NoError(t, err) + assert.Equal(t, 2, fieldData.RowNum()) + + err = v.convertFunc("", fieldData) + assert.Error(t, err) + assert.Equal(t, 2, fieldData.RowNum()) + }) +} + +func Test_CSVRowConsumerHandleIntPK(t *testing.T) { + ctx := context.Background() + + t.Run("nil input", func(t *testing.T) { + var consumer *CSVRowConsumer + err := consumer.Handle(nil) + assert.Error(t, err) + }) + + schema := &schemapb.CollectionSchema{ + Name: "schema", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 101, + Name: "FieldInt64", + IsPrimaryKey: true, + AutoID: true, + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 102, + Name: "FieldVarchar", + DataType: schemapb.DataType_VarChar, + }, + { + FieldID: 103, + Name: "FieldFloat", + DataType: schemapb.DataType_Float, + }, + }, + } + createConsumeFunc := func(shardNum int32, partitionIDs []int64, flushFunc ImportFlushFunc) *CSVRowConsumer { + collectionInfo, err := NewCollectionInfo(schema, shardNum, partitionIDs) + assert.NoError(t, err) + + idAllocator := newIDAllocator(ctx, t, nil) + consumer, err := NewCSVRowConsumer(ctx, collectionInfo, idAllocator, 1, flushFunc) + assert.NotNil(t, consumer) + assert.NoError(t, err) + + return consumer + } + + t.Run("auto pk no partition key", func(t *testing.T) { + flushErrFunc := func(fields BlockData, shard int, partID int64) error { + return errors.New("dummy error") + } + + // rows to input + inputRowCount := 100 + input := make([]map[storage.FieldID]string, inputRowCount) + for i := 0; i < inputRowCount; i++ { + input[i] = map[storage.FieldID]string{ + 102: "string", + 103: "122.5", + } + } + + shardNum := int32(2) + partitionID := int64(1) + consumer := createConsumeFunc(shardNum, []int64{partitionID}, flushErrFunc) + consumer.rowIDAllocator = newIDAllocator(ctx, t, errors.New("error")) + + waitFlushRowCount := 10 + fieldData := createFieldsData(schema, waitFlushRowCount) + consumer.shardsData = createShardsData(schema, fieldData, shardNum, []int64{partitionID}) + + // nil input will trigger force flush, flushErrFunc returns error + err := consumer.Handle(nil) + assert.Error(t, err) + + // optional flush, flushErrFunc returns error + err = consumer.Handle(input) + assert.Error(t, err) + + // reset flushFunc + var callTime int32 + var flushedRowCount int + consumer.callFlushFunc = func(fields BlockData, shard int, partID int64) error { + callTime++ + assert.Less(t, int32(shard), shardNum) + assert.Equal(t, partitionID, partID) + assert.Greater(t, len(fields), 0) + for _, v := range fields { + assert.Greater(t, v.RowNum(), 0) + } + flushedRowCount += fields[102].RowNum() + return nil + } + // optional flush succeed, each shard has 10 rows, idErrAllocator returns error + err = consumer.Handle(input) + assert.Error(t, err) + assert.Equal(t, waitFlushRowCount*int(shardNum), flushedRowCount) + assert.Equal(t, shardNum, callTime) + + // optional flush again, large blockSize, nothing flushed, idAllocator returns error + callTime = int32(0) + flushedRowCount = 0 + consumer.shardsData = createShardsData(schema, fieldData, shardNum, 
[]int64{partitionID}) + consumer.rowIDAllocator = nil + consumer.blockSize = 8 * 1024 * 1024 + err = consumer.Handle(input) + assert.Error(t, err) + assert.Equal(t, 0, flushedRowCount) + assert.Equal(t, int32(0), callTime) + + // idAllocator is ok, consume 100 rows, the previous shardsData(10 rows per shard) is flushed + callTime = int32(0) + flushedRowCount = 0 + consumer.blockSize = 1 + consumer.rowIDAllocator = newIDAllocator(ctx, t, nil) + err = consumer.Handle(input) + assert.NoError(t, err) + assert.Equal(t, waitFlushRowCount*int(shardNum), flushedRowCount) + assert.Equal(t, shardNum, callTime) + assert.Equal(t, int64(inputRowCount), consumer.RowCount()) + assert.Equal(t, 2, len(consumer.IDRange())) + assert.Equal(t, int64(1), consumer.IDRange()[0]) + assert.Equal(t, int64(1+inputRowCount), consumer.IDRange()[1]) + + // call handle again, the 100 rows are flushed + callTime = int32(0) + flushedRowCount = 0 + err = consumer.Handle(nil) + assert.NoError(t, err) + assert.Equal(t, inputRowCount, flushedRowCount) + assert.Equal(t, shardNum, callTime) + }) + + schema.Fields[0].AutoID = false + + t.Run("manual pk no partition key", func(t *testing.T) { + shardNum := int32(1) + partitionID := int64(100) + + var callTime int32 + var flushedRowCount int + flushFunc := func(fields BlockData, shard int, partID int64) error { + callTime++ + assert.Less(t, int32(shard), shardNum) + assert.Equal(t, partitionID, partID) + assert.Greater(t, len(fields), 0) + flushedRowCount += fields[102].RowNum() + return nil + } + + consumer := createConsumeFunc(shardNum, []int64{partitionID}, flushFunc) + + // failed to convert pk to int value + input := make([]map[storage.FieldID]string, 1) + input[0] = map[int64]string{ + 101: "abc", + 102: "string", + 103: "11.11", + } + + err := consumer.Handle(input) + assert.Error(t, err) + + // failed to hash to partition + input[0] = map[int64]string{ + 101: "99", + 102: "string", + 103: "11.11", + } + consumer.collectionInfo.PartitionIDs = nil + err = consumer.Handle(input) + assert.Error(t, err) + consumer.collectionInfo.PartitionIDs = []int64{partitionID} + + // failed to convert value + input[0] = map[int64]string{ + 101: "99", + 102: "string", + 103: "abc.11", + } + err = consumer.Handle(input) + assert.Error(t, err) + consumer.shardsData = createShardsData(schema, nil, shardNum, []int64{partitionID}) // in-memory data is dirty, reset + + // succeed, consum 1 row + input[0] = map[int64]string{ + 101: "99", + 102: "string", + 103: "11.11", + } + err = consumer.Handle(input) + assert.NoError(t, err) + assert.Equal(t, int64(1), consumer.RowCount()) + assert.Equal(t, 0, len(consumer.IDRange())) + + // call handle again, the 1 row is flushed + callTime = int32(0) + flushedRowCount = 0 + err = consumer.Handle(nil) + assert.NoError(t, err) + assert.Equal(t, 1, flushedRowCount) + assert.Equal(t, shardNum, callTime) + }) + + schema.Fields[1].IsPartitionKey = true + + t.Run("manual pk with partition key", func(t *testing.T) { + // 10 partitions + partitionIDs := make([]int64, 0) + for i := 0; i < 10; i++ { + partitionIDs = append(partitionIDs, int64(i)) + } + + shardNum := int32(2) + var flushedRowCount int + flushFunc := func(fields BlockData, shard int, partID int64) error { + assert.Less(t, int32(shard), shardNum) + assert.Contains(t, partitionIDs, partID) + assert.Greater(t, len(fields), 0) + flushedRowCount += fields[102].RowNum() + return nil + } + + consumer := createConsumeFunc(shardNum, partitionIDs, flushFunc) + + // rows to input + inputRowCount := 100 + input := 
make([]map[storage.FieldID]string, inputRowCount) + for i := 0; i < inputRowCount; i++ { + input[i] = map[int64]string{ + 101: strconv.Itoa(i), + 102: "partitionKey_" + strconv.Itoa(i), + 103: "6.18", + } + } + + // 100 rows are consumed to different partitions + err := consumer.Handle(input) + assert.NoError(t, err) + assert.Equal(t, int64(inputRowCount), consumer.RowCount()) + + // call handle again, 100 rows are flushed + flushedRowCount = 0 + err = consumer.Handle(nil) + assert.NoError(t, err) + assert.Equal(t, inputRowCount, flushedRowCount) + }) +} + +func Test_CSVRowConsumerHandleVarcharPK(t *testing.T) { + ctx := context.Background() + + schema := &schemapb.CollectionSchema{ + Name: "schema", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 101, + Name: "FieldVarchar", + IsPrimaryKey: true, + AutoID: false, + DataType: schemapb.DataType_VarChar, + }, + { + FieldID: 102, + Name: "FieldInt64", + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 103, + Name: "FieldFloat", + DataType: schemapb.DataType_Float, + }, + }, + } + + createConsumeFunc := func(shardNum int32, partitionIDs []int64, flushFunc ImportFlushFunc) *CSVRowConsumer { + collectionInfo, err := NewCollectionInfo(schema, shardNum, partitionIDs) + assert.NoError(t, err) + + idAllocator := newIDAllocator(ctx, t, nil) + consumer, err := NewCSVRowConsumer(ctx, collectionInfo, idAllocator, 1, flushFunc) + assert.NotNil(t, consumer) + assert.NoError(t, err) + + return consumer + } + + t.Run("no partition key", func(t *testing.T) { + shardNum := int32(2) + partitionID := int64(1) + var callTime int32 + var flushedRowCount int + flushFunc := func(fields BlockData, shard int, partID int64) error { + callTime++ + assert.Less(t, int32(shard), shardNum) + assert.Equal(t, partitionID, partID) + assert.Greater(t, len(fields), 0) + for _, v := range fields { + assert.Greater(t, v.RowNum(), 0) + } + flushedRowCount += fields[102].RowNum() + return nil + } + + consumer := createConsumeFunc(shardNum, []int64{partitionID}, flushFunc) + consumer.shardsData = createShardsData(schema, nil, shardNum, []int64{partitionID}) + + // string type primary key cannot be auto-generated + input := make([]map[storage.FieldID]string, 1) + input[0] = map[storage.FieldID]string{ + 101: "primaryKey_0", + 102: "1", + 103: "1.252", + } + + consumer.collectionInfo.PrimaryKey.AutoID = true + err := consumer.Handle(input) + assert.Error(t, err) + consumer.collectionInfo.PrimaryKey.AutoID = false + + // failed to hash to partition + consumer.collectionInfo.PartitionIDs = nil + err = consumer.Handle(input) + assert.Error(t, err) + consumer.collectionInfo.PartitionIDs = []int64{partitionID} + + // rows to input + inputRowCount := 100 + input = make([]map[storage.FieldID]string, inputRowCount) + for i := 0; i < inputRowCount; i++ { + input[i] = map[int64]string{ + 101: "primaryKey_" + strconv.Itoa(i), + 102: strconv.Itoa(i), + 103: "6.18", + } + } + + err = consumer.Handle(input) + assert.NoError(t, err) + assert.Equal(t, int64(inputRowCount), consumer.RowCount()) + assert.Equal(t, 0, len(consumer.IDRange())) + + // call handle again, 100 rows are flushed + err = consumer.Handle(nil) + assert.NoError(t, err) + assert.Equal(t, inputRowCount, flushedRowCount) + assert.Equal(t, shardNum, callTime) + }) + + schema.Fields[1].IsPartitionKey = true + t.Run("has partition key", func(t *testing.T) { + partitionIDs := make([]int64, 0) + for i := 0; i < 10; i++ { + partitionIDs = append(partitionIDs, int64(i)) + } + + shardNum := int32(2) + var flushedRowCount int + 
flushFunc := func(fields BlockData, shard int, partID int64) error { + assert.Less(t, int32(shard), shardNum) + assert.Contains(t, partitionIDs, partID) + assert.Greater(t, len(fields), 0) + flushedRowCount += fields[102].RowNum() + return nil + } + + consumer := createConsumeFunc(shardNum, partitionIDs, flushFunc) + + // rows to input + inputRowCount := 100 + input := make([]map[storage.FieldID]string, inputRowCount) + for i := 0; i < inputRowCount; i++ { + input[i] = map[int64]string{ + 101: "primaryKey_" + strconv.Itoa(i), + 102: strconv.Itoa(i), + 103: "6.18", + } + } + + err := consumer.Handle(input) + assert.NoError(t, err) + assert.Equal(t, int64(inputRowCount), consumer.RowCount()) + assert.Equal(t, 0, len(consumer.IDRange())) + + // call handle again, 100 rows are flushed + err = consumer.Handle(nil) + assert.NoError(t, err) + assert.Equal(t, inputRowCount, flushedRowCount) + }) +} + +func Test_CSVRowConsumerHashToPartition(t *testing.T) { + ctx := context.Background() + + schema := &schemapb.CollectionSchema{ + Name: "schema", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "ID", + IsPrimaryKey: true, + AutoID: false, + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 101, + Name: "FieldVarchar", + DataType: schemapb.DataType_VarChar, + }, + { + FieldID: 102, + Name: "FieldInt64", + DataType: schemapb.DataType_Int64, + }, + }, + } + + partitionID := int64(1) + collectionInfo, err := NewCollectionInfo(schema, 2, []int64{partitionID}) + assert.NoError(t, err) + consumer, err := NewCSVRowConsumer(ctx, collectionInfo, nil, 16, nil) + assert.NoError(t, err) + assert.NotNil(t, consumer) + input := map[int64]string{ + 100: "1", + 101: "abc", + 102: "100", + } + t.Run("no partition key", func(t *testing.T) { + partID, err := consumer.hashToPartition(input, 0) + assert.NoError(t, err) + assert.Equal(t, partitionID, partID) + }) + + t.Run("partition list is empty", func(t *testing.T) { + collectionInfo.PartitionIDs = []int64{} + partID, err := consumer.hashToPartition(input, 0) + assert.Error(t, err) + assert.Equal(t, int64(0), partID) + collectionInfo.PartitionIDs = []int64{partitionID} + }) + + schema.Fields[1].IsPartitionKey = true + err = collectionInfo.resetSchema(schema) + assert.NoError(t, err) + collectionInfo.PartitionIDs = []int64{1, 2, 3} + + t.Run("varchar partition key", func(t *testing.T) { + input = map[int64]string{ + 100: "1", + 101: "abc", + 102: "100", + } + + partID, err := consumer.hashToPartition(input, 0) + assert.NoError(t, err) + assert.Contains(t, collectionInfo.PartitionIDs, partID) + }) + + schema.Fields[1].IsPartitionKey = false + schema.Fields[2].IsPartitionKey = true + err = collectionInfo.resetSchema(schema) + assert.NoError(t, err) + + t.Run("int64 partition key", func(t *testing.T) { + input = map[int64]string{ + 100: "1", + 101: "abc", + 102: "ab0", + } + // parse int failed + partID, err := consumer.hashToPartition(input, 0) + assert.Error(t, err) + assert.Equal(t, int64(0), partID) + + // succeed + input[102] = "100" + partID, err = consumer.hashToPartition(input, 0) + assert.NoError(t, err) + assert.Contains(t, collectionInfo.PartitionIDs, partID) + }) +} diff --git a/internal/util/importutil/csv_parser.go b/internal/util/importutil/csv_parser.go new file mode 100644 index 0000000000..3c4cb49b43 --- /dev/null +++ b/internal/util/importutil/csv_parser.go @@ -0,0 +1,317 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importutil + +import ( + "context" + "encoding/json" + "fmt" + "io" + "strconv" + "strings" + + "github.com/cockroachdb/errors" + "go.uber.org/zap" + + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/typeutil" +) + +type CSVParser struct { + ctx context.Context // for canceling parse process + collectionInfo *CollectionInfo // collection details including schema + bufRowCount int // max rows in a buffer + fieldsName []string // fieldsName(header name) in the csv file + updateProgressFunc func(percent int64) // update working progress percent value +} + +func NewCSVParser(ctx context.Context, collectionInfo *CollectionInfo, updateProgressFunc func(percent int64)) (*CSVParser, error) { + if collectionInfo == nil { + log.Warn("CSV parser: collection schema is nil") + return nil, errors.New("collection schema is nil") + } + + parser := &CSVParser{ + ctx: ctx, + collectionInfo: collectionInfo, + bufRowCount: 1024, + fieldsName: make([]string, 0), + updateProgressFunc: updateProgressFunc, + } + parser.SetBufSize() + return parser, nil +} + +func (p *CSVParser) SetBufSize() { + schema := p.collectionInfo.Schema + sizePerRecord, _ := typeutil.EstimateSizePerRecord(schema) + if sizePerRecord <= 0 { + return + } + + bufRowCount := p.bufRowCount + for { + if bufRowCount*sizePerRecord > ReadBufferSize { + bufRowCount-- + } else { + break + } + } + if bufRowCount <= 0 { + bufRowCount = 1 + } + log.Info("CSV parser: reset bufRowCount", zap.Int("sizePerRecord", sizePerRecord), zap.Int("bufRowCount", bufRowCount)) + p.bufRowCount = bufRowCount +} + +func (p *CSVParser) combineDynamicRow(dynamicValues map[string]string, row map[storage.FieldID]string) error { + if p.collectionInfo.DynamicField == nil { + return nil + } + + dynamicFieldID := p.collectionInfo.DynamicField.GetFieldID() + // combine the dynamic field value + // valid input: + // id,vector,x,$meta id,vector,$meta + // case1: 1,"[]",8,"{""y"": 8}" ==>> 1,"[]","{""y"": 8, ""x"": 8}" + // case2: 1,"[]",8,"{}" ==>> 1,"[]","{""x"": 8}" + // case3: 1,"[]",,"{""x"": 8}" + // case4: 1,"[]",8, ==>> 1,"[]","{""x"": 8}" + // case5: 1,"[]",, + value, ok := row[dynamicFieldID] + // ignore empty string field + if value == "" { + ok = false + } + if len(dynamicValues) > 0 { + mp := make(map[string]interface{}) + if ok { + // case 1/2 + // $meta is JSON type field, we first convert it to map[string]interface{} + // then merge other dynamic field into it + desc := json.NewDecoder(strings.NewReader(value)) + desc.UseNumber() + if err := desc.Decode(&mp); err != nil { + log.Warn("CSV parser: illegal value for dynamic field, not a JSON object") + return errors.New("illegal value for dynamic field, not a JSON object") + } + } + // case 4 + for k, v := range dynamicValues { + // ignore empty 
string field + if v == "" { + continue + } + var value interface{} + + desc := json.NewDecoder(strings.NewReader(v)) + desc.UseNumber() + if err := desc.Decode(&value); err != nil { + // Decode a string will cause error, like "abcd" + mp[k] = v + continue + } + + if num, ok := value.(json.Number); ok { + // Decode may convert "123ab" to 123, so need additional check + if _, err := strconv.ParseFloat(v, 64); err != nil { + mp[k] = v + } else { + mp[k] = num + } + } else if arr, ok := value.([]interface{}); ok { + mp[k] = arr + } else if obj, ok := value.(map[string]interface{}); ok { + mp[k] = obj + } else if b, ok := value.(bool); ok { + mp[k] = b + } + } + bs, err := json.Marshal(mp) + if err != nil { + log.Warn("CSV parser: illegal value for dynamic field, not a JSON object") + return errors.New("illegal value for dynamic field, not a JSON object") + } + row[dynamicFieldID] = string(bs) + } else if !ok && len(dynamicValues) == 0 { + // case 5 + row[dynamicFieldID] = "{}" + } + // else case 3 + + return nil +} + +func (p *CSVParser) verifyRow(raw []string) (map[storage.FieldID]string, error) { + row := make(map[storage.FieldID]string) + dynamicValues := make(map[string]string) + + for i := 0; i < len(p.fieldsName); i++ { + fieldName := p.fieldsName[i] + fieldID, ok := p.collectionInfo.Name2FieldID[fieldName] + + if fieldID == p.collectionInfo.PrimaryKey.GetFieldID() && p.collectionInfo.PrimaryKey.GetAutoID() { + // primary key is auto-id, no need to provide + log.Warn("CSV parser: the primary key is auto-generated, no need to provide", zap.String("fieldName", fieldName)) + return nil, fmt.Errorf("the primary key '%s' is auto-generated, no need to provide", fieldName) + } + + if ok { + row[fieldID] = raw[i] + } else if p.collectionInfo.DynamicField != nil { + // collection have dynamic field. put it to dynamicValues + dynamicValues[fieldName] = raw[i] + } else { + // no dynamic field. if user provided redundant field, return error + log.Warn("CSV parser: the field is not defined in collection schema", zap.String("fieldName", fieldName)) + return nil, fmt.Errorf("the field '%s' is not defined in collection schema", fieldName) + } + } + // some fields not provided? 
+ if len(row) != len(p.collectionInfo.Name2FieldID) { + for k, v := range p.collectionInfo.Name2FieldID { + if p.collectionInfo.DynamicField != nil && v == p.collectionInfo.DynamicField.GetFieldID() { + // ignore dyanmic field, user don't have to provide values for dynamic field + continue + } + + if v == p.collectionInfo.PrimaryKey.GetFieldID() && p.collectionInfo.PrimaryKey.GetAutoID() { + // ignore auto-generaed primary key + continue + } + _, ok := row[v] + if !ok { + // not auto-id primary key, no dynamic field, must provide value + log.Warn("CSV parser: a field value is missed", zap.String("fieldName", k)) + return nil, fmt.Errorf("value of field '%s' is missed", k) + } + } + } + // combine the redundant pairs into dynamic field(if has) + err := p.combineDynamicRow(dynamicValues, row) + if err != nil { + log.Warn("CSV parser: failed to combine dynamic values", zap.Error(err)) + return nil, err + } + + return row, nil +} + +func (p *CSVParser) ParseRows(reader *IOReader, handle CSVRowHandler) error { + if reader == nil || handle == nil { + log.Warn("CSV Parser: CSV parse handle is nil") + return errors.New("CSV parse handle is nil") + } + // discard bom in the file + RuneScanner := reader.r.(io.RuneScanner) + bom, _, err := RuneScanner.ReadRune() + if err == io.EOF { + log.Info("CSV Parser: row count is 0") + return nil + } + if err != nil { + return err + } + if bom != '\ufeff' { + RuneScanner.UnreadRune() + } + r := NewReader(reader.r) + + oldPercent := int64(0) + updateProgress := func() { + if p.updateProgressFunc != nil && reader.fileSize > 0 { + percent := (r.InputOffset() * ProgressValueForPersist) / reader.fileSize + if percent > oldPercent { // avoid too many log + log.Debug("CSV parser: working progress", zap.Int64("offset", r.InputOffset()), + zap.Int64("fileSize", reader.fileSize), zap.Int64("percent", percent)) + } + oldPercent = percent + p.updateProgressFunc(percent) + } + } + isEmpty := true + for { + // read the fields value + fieldsName, err := r.Read() + if err == io.EOF { + break + } else if err != nil { + log.Warn("CSV Parser: failed to parse the field value", zap.Error(err)) + return fmt.Errorf("failed to read the field value, error: %w", err) + } + p.fieldsName = fieldsName + // read buffer + buf := make([]map[storage.FieldID]string, 0, p.bufRowCount) + for { + // read the row value + values, err := r.Read() + + if err == io.EOF { + break + } else if err != nil { + log.Warn("CSV parser: failed to parse row value", zap.Error(err)) + return fmt.Errorf("failed to parse row value, error: %w", err) + } + + row, err := p.verifyRow(values) + if err != nil { + return err + } + + updateProgress() + + buf = append(buf, row) + if len(buf) >= p.bufRowCount { + isEmpty = false + if err = handle.Handle(buf); err != nil { + log.Warn("CSV parser: failed to convert row value to entity", zap.Error(err)) + return fmt.Errorf("failed to convert row value to entity, error: %w", err) + } + // clean the buffer + buf = make([]map[storage.FieldID]string, 0, p.bufRowCount) + } + } + if len(buf) > 0 { + isEmpty = false + if err = handle.Handle(buf); err != nil { + log.Warn("CSV parser: failed to convert row value to entity", zap.Error(err)) + return fmt.Errorf("failed to convert row value to entity, error: %w", err) + } + } + + // outside context might be canceled(service stop, or future enhancement for canceling import task) + if isCanceled(p.ctx) { + log.Warn("CSV parser: import task was canceled") + return errors.New("import task was canceled") + } + // nolint + // this break means we 
require the first row must be fieldsName + break + } + + // empty file is allowed, don't return error + if isEmpty { + log.Info("CSV Parser: row count is 0") + return nil + } + + updateProgress() + + // send nil to notify the handler all have done + return handle.Handle(nil) +} diff --git a/internal/util/importutil/csv_parser_test.go b/internal/util/importutil/csv_parser_test.go new file mode 100644 index 0000000000..70738e7fc4 --- /dev/null +++ b/internal/util/importutil/csv_parser_test.go @@ -0,0 +1,414 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package importutil + +import ( + "context" + "strings" + "testing" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/common" +) + +type mockCSVRowConsumer struct { + handleErr error + rows []map[storage.FieldID]string + handleCount int +} + +func (v *mockCSVRowConsumer) Handle(rows []map[storage.FieldID]string) error { + if v.handleErr != nil { + return v.handleErr + } + if rows != nil { + v.rows = append(v.rows, rows...) 
+ } + v.handleCount++ + return nil +} + +func Test_CSVParserAdjustBufSize(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + schema := sampleSchema() + collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1}) + assert.NoError(t, err) + parser, err := NewCSVParser(ctx, collectionInfo, nil) + assert.NoError(t, err) + assert.NotNil(t, parser) + assert.Greater(t, parser.bufRowCount, 0) + // huge row + schema.Fields[9].TypeParams = []*commonpb.KeyValuePair{ + {Key: common.DimKey, Value: "32768"}, + } + parser, err = NewCSVParser(ctx, collectionInfo, nil) + assert.NoError(t, err) + assert.NotNil(t, parser) + assert.Greater(t, parser.bufRowCount, 0) +} + +func Test_CSVParserParseRows_IntPK(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + schema := sampleSchema() + collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1}) + assert.NoError(t, err) + parser, err := NewCSVParser(ctx, collectionInfo, nil) + assert.NoError(t, err) + assert.NotNil(t, parser) + + consumer := &mockCSVRowConsumer{ + handleErr: nil, + rows: make([]map[int64]string, 0), + handleCount: 0, + } + + reader := strings.NewReader( + `FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector + true,10,101,1001,10001,3.14,1.56,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]"`) + + t.Run("parse success", func(t *testing.T) { + err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer) + assert.NoError(t, err) + + // empty file + reader = strings.NewReader(``) + err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(0)}, consumer) + assert.NoError(t, err) + + // only have headers no value row + reader = strings.NewReader(`FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector`) + err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer) + assert.NoError(t, err) + + // csv file have bom + reader = strings.NewReader(`\ufeffFieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector`) + err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer) + assert.NoError(t, err) + }) + + t.Run("error cases", func(t *testing.T) { + // handler is nil + + err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(0)}, nil) + assert.Error(t, err) + + // csv parse error, fields len error + reader := strings.NewReader( + `FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector + 0,100,1000,99999999999999999,3,1,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]"`) + err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer) + assert.Error(t, err) + + // redundant field + reader = strings.NewReader( + `dummy,FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector + 1,true,0,100,1000,99999999999999999,3,1,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]"`) + err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer) + assert.Error(t, err) + + // field missed + reader = strings.NewReader( + `FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector + 0,100,1000,99999999999999999,3,1,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]"`) 
+ err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer) + assert.Error(t, err) + + // handle() error + content := `FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector + true,0,100,1000,99999999999999999,3,1,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]"` + consumer.handleErr = errors.New("error") + reader = strings.NewReader(content) + err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer) + assert.Error(t, err) + + // canceled + consumer.handleErr = nil + cancel() + reader = strings.NewReader(content) + err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer) + assert.Error(t, err) + }) +} + +func Test_CSVParserCombineDynamicRow(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + schema := &schemapb.CollectionSchema{ + Name: "schema", + Description: "schema", + EnableDynamicField: true, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 106, + Name: "FieldID", + IsPrimaryKey: true, + AutoID: false, + Description: "int64", + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 113, + Name: "FieldDynamic", + IsPrimaryKey: false, + IsDynamic: true, + Description: "dynamic field", + DataType: schemapb.DataType_JSON, + }, + }, + } + + collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1}) + assert.NoError(t, err) + parser, err := NewCSVParser(ctx, collectionInfo, nil) + assert.NoError(t, err) + assert.NotNil(t, parser) + + // valid input: + // id,vector,x,$meta id,vector,$meta + // case1: 1,"[]",8,"{""y"": 8}" ==>> 1,"[]","{""y"": 8, ""x"": 8}" + // case2: 1,"[]",8,"{}" ==>> 1,"[]","{""x"": 8}" + // case3: 1,"[]",,"{""x"": 8}" + // case4: 1,"[]",8, ==>> 1,"[]","{""x"": 8}" + // case5: 1,"[]",, + + t.Run("value combined for dynamic field", func(t *testing.T) { + dynamicValues := map[string]string{ + "x": "88", + } + row := map[storage.FieldID]string{ + 106: "1", + 113: `{"y": 8}`, + } + err = parser.combineDynamicRow(dynamicValues, row) + assert.NoError(t, err) + assert.Contains(t, row, int64(113)) + assert.Contains(t, row[113], "x") + assert.Contains(t, row[113], "y") + + row = map[storage.FieldID]string{ + 106: "1", + 113: `{}`, + } + err = parser.combineDynamicRow(dynamicValues, row) + assert.NoError(t, err) + assert.Contains(t, row, int64(113)) + assert.Contains(t, row[113], "x") + }) + + t.Run("JSON format string/object for dynamic field", func(t *testing.T) { + dynamicValues := map[string]string{} + row := map[storage.FieldID]string{ + 106: "1", + 113: `{"x": 8}`, + } + err = parser.combineDynamicRow(dynamicValues, row) + assert.NoError(t, err) + assert.Contains(t, row, int64(113)) + }) + + t.Run("dynamic field is hidden", func(t *testing.T) { + dynamicValues := map[string]string{ + "x": "8", + } + row := map[storage.FieldID]string{ + 106: "1", + } + err = parser.combineDynamicRow(dynamicValues, row) + assert.NoError(t, err) + assert.Contains(t, row, int64(113)) + assert.Contains(t, row, int64(113)) + assert.Contains(t, row[113], "x") + }) + + t.Run("no values for dynamic field", func(t *testing.T) { + dynamicValues := map[string]string{} + row := map[storage.FieldID]string{ + 106: "1", + } + err = parser.combineDynamicRow(dynamicValues, row) + assert.NoError(t, err) + assert.Contains(t, row, int64(113)) + assert.Equal(t, "{}", row[113]) + }) + + t.Run("empty value for dynamic field", func(t *testing.T) { + dynamicValues := map[string]string{ + "x": "", + } + row := 
map[storage.FieldID]string{ + 106: "1", + 113: `{"y": 8}`, + } + err = parser.combineDynamicRow(dynamicValues, row) + assert.NoError(t, err) + assert.Contains(t, row, int64(113)) + assert.Contains(t, row[113], "y") + assert.NotContains(t, row[113], "x") + + row = map[storage.FieldID]string{ + 106: "1", + 113: "", + } + err = parser.combineDynamicRow(dynamicValues, row) + assert.NoError(t, err) + assert.Equal(t, "{}", row[113]) + + dynamicValues = map[string]string{ + "x": "5", + } + err = parser.combineDynamicRow(dynamicValues, row) + assert.NoError(t, err) + assert.Contains(t, row[113], "x") + }) + + t.Run("invalid input for dynamic field", func(t *testing.T) { + dynamicValues := map[string]string{ + "x": "8", + } + row := map[storage.FieldID]string{ + 106: "1", + 113: "5", + } + err = parser.combineDynamicRow(dynamicValues, row) + assert.Error(t, err) + + row = map[storage.FieldID]string{ + 106: "1", + 113: "abc", + } + err = parser.combineDynamicRow(dynamicValues, row) + assert.Error(t, err) + }) + + t.Run("not allow dynamic values if no dynamic field", func(t *testing.T) { + parser.collectionInfo.DynamicField = nil + dynamicValues := map[string]string{ + "x": "8", + } + row := map[storage.FieldID]string{ + 106: "1", + } + err = parser.combineDynamicRow(dynamicValues, row) + assert.NoError(t, err) + assert.NotContains(t, row, int64(113)) + }) +} + +func Test_CSVParserVerifyRow(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + schema := &schemapb.CollectionSchema{ + Name: "schema", + Description: "schema", + EnableDynamicField: true, + Fields: []*schemapb.FieldSchema{ + { + FieldID: 106, + Name: "FieldID", + IsPrimaryKey: true, + AutoID: false, + Description: "int64", + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 113, + Name: "FieldDynamic", + IsPrimaryKey: false, + IsDynamic: true, + Description: "dynamic field", + DataType: schemapb.DataType_JSON, + }, + }, + } + + collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1}) + assert.NoError(t, err) + parser, err := NewCSVParser(ctx, collectionInfo, nil) + assert.NoError(t, err) + assert.NotNil(t, parser) + + t.Run("not auto-id, dynamic field provided", func(t *testing.T) { + parser.fieldsName = []string{"FieldID", "FieldDynamic", "y"} + raw := []string{"1", `{"x": 8}`, "true"} + row, err := parser.verifyRow(raw) + assert.NoError(t, err) + assert.Contains(t, row, int64(106)) + assert.Contains(t, row, int64(113)) + assert.Contains(t, row[113], "x") + assert.Contains(t, row[113], "y") + }) + + t.Run("not auto-id, dynamic field not provided", func(t *testing.T) { + parser.fieldsName = []string{"FieldID"} + raw := []string{"1"} + row, err := parser.verifyRow(raw) + assert.NoError(t, err) + assert.Contains(t, row, int64(106)) + assert.Contains(t, row, int64(113)) + assert.Contains(t, "{}", row[113]) + }) + + t.Run("not auto-id, invalid input dynamic field", func(t *testing.T) { + parser.fieldsName = []string{"FieldID", "FieldDynamic", "y"} + raw := []string{"1", "true", "true"} + _, err = parser.verifyRow(raw) + assert.Error(t, err) + }) + + schema.Fields[0].AutoID = true + err = collectionInfo.resetSchema(schema) + assert.NoError(t, err) + t.Run("no need to provide value for auto-id", func(t *testing.T) { + parser.fieldsName = []string{"FieldID", "FieldDynamic", "y"} + raw := []string{"1", `{"x": 8}`, "true"} + _, err := parser.verifyRow(raw) + assert.Error(t, err) + + parser.fieldsName = []string{"FieldDynamic", "y"} + raw = []string{`{"x": 8}`, "true"} + row, err := 
parser.verifyRow(raw) + assert.NoError(t, err) + assert.Contains(t, row, int64(113)) + }) + + schema.Fields[1].IsDynamic = false + err = collectionInfo.resetSchema(schema) + assert.NoError(t, err) + t.Run("auto id, no dynamic field", func(t *testing.T) { + parser.fieldsName = []string{"FieldDynamic", "y"} + raw := []string{`{"x": 8}`, "true"} + _, err := parser.verifyRow(raw) + assert.Error(t, err) + + // miss FieldDynamic + parser.fieldsName = []string{} + raw = []string{} + _, err = parser.verifyRow(raw) + assert.Error(t, err) + }) +} diff --git a/internal/util/importutil/csv_reader.go b/internal/util/importutil/csv_reader.go new file mode 100644 index 0000000000..b0c51d9398 --- /dev/null +++ b/internal/util/importutil/csv_reader.go @@ -0,0 +1,470 @@ +// Copied from go 1.20 as go 1.18 csv reader not implement inputOffset + +package importutil + +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package csv reads and writes comma-separated values (CSV) files. +// There are many kinds of CSV files; this package supports the format +// described in RFC 4180. +// +// A csv file contains zero or more records of one or more fields per record. +// Each record is separated by the newline character. The final record may +// optionally be followed by a newline character. +// +// field1,field2,field3 +// +// White space is considered part of a field. +// +// Carriage returns before newline characters are silently removed. +// +// Blank lines are ignored. A line with only whitespace characters (excluding +// the ending newline character) is not considered a blank line. +// +// Fields which start and stop with the quote character " are called +// quoted-fields. The beginning and ending quote are not part of the +// field. +// +// The source: +// +// normal string,"quoted-field" +// +// results in the fields +// +// {`normal string`, `quoted-field`} +// +// Within a quoted-field a quote character followed by a second quote +// character is considered a single quote. +// +// "the ""word"" is true","a ""quoted-field""" +// +// results in +// +// {`the "word" is true`, `a "quoted-field"`} +// +// Newlines and commas may be included in a quoted-field +// +// "Multi-line +// field","comma is ," +// +// results in +// +// {`Multi-line +// field`, `comma is ,`} + +import ( + "bufio" + "bytes" + "fmt" + "io" + "unicode" + "unicode/utf8" + + "github.com/cockroachdb/errors" +) + +// A ParseError is returned for parsing errors. +// Line numbers are 1-indexed and columns are 0-indexed. +type ParseError struct { + StartLine int // Line where the record starts + Line int // Line where the error occurred + Column int // Column (1-based byte index) where the error occurred + Err error // The actual error +} + +func (e *ParseError) Error() string { + if e.Err == ErrFieldCount { + return fmt.Sprintf("record on line %d: %v", e.Line, e.Err) + } + if e.StartLine != e.Line { + return fmt.Sprintf("record on line %d; parse error on line %d, column %d: %v", e.StartLine, e.Line, e.Column, e.Err) + } + return fmt.Sprintf("parse error on line %d, column %d: %v", e.Line, e.Column, e.Err) +} + +func (e *ParseError) Unwrap() error { return e.Err } + +// These are the errors that can be returned in ParseError.Err. 
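+// They mirror the sentinel errors defined by the standard library's encoding/csv
+// package, since this file is a copy of the Go 1.20 csv reader.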
+var ( + ErrBareQuote = errors.New("bare \" in non-quoted-field") + ErrQuote = errors.New("extraneous or missing \" in quoted-field") + ErrFieldCount = errors.New("wrong number of fields") + + // Deprecated: ErrTrailingComma is no longer used. + ErrTrailingComma = errors.New("extra delimiter at end of line") +) + +var errInvalidDelim = errors.New("csv: invalid field or comment delimiter") + +func validDelim(r rune) bool { + return r != 0 && r != '"' && r != '\r' && r != '\n' && utf8.ValidRune(r) && r != utf8.RuneError +} + +// A Reader reads records from a CSV-encoded file. +// +// As returned by NewReader, a Reader expects input conforming to RFC 4180. +// The exported fields can be changed to customize the details before the +// first call to Read or ReadAll. +// +// The Reader converts all \r\n sequences in its input to plain \n, +// including in multiline field values, so that the returned data does +// not depend on which line-ending convention an input file uses. +type Reader struct { + // Comma is the field delimiter. + // It is set to comma (',') by NewReader. + // Comma must be a valid rune and must not be \r, \n, + // or the Unicode replacement character (0xFFFD). + Comma rune + + // Comment, if not 0, is the comment character. Lines beginning with the + // Comment character without preceding whitespace are ignored. + // With leading whitespace the Comment character becomes part of the + // field, even if TrimLeadingSpace is true. + // Comment must be a valid rune and must not be \r, \n, + // or the Unicode replacement character (0xFFFD). + // It must also not be equal to Comma. + Comment rune + + // FieldsPerRecord is the number of expected fields per record. + // If FieldsPerRecord is positive, Read requires each record to + // have the given number of fields. If FieldsPerRecord is 0, Read sets it to + // the number of fields in the first record, so that future records must + // have the same field count. If FieldsPerRecord is negative, no check is + // made and records may have a variable number of fields. + FieldsPerRecord int + + // If LazyQuotes is true, a quote may appear in an unquoted field and a + // non-doubled quote may appear in a quoted field. + LazyQuotes bool + + // If TrimLeadingSpace is true, leading white space in a field is ignored. + // This is done even if the field delimiter, Comma, is white space. + TrimLeadingSpace bool + + // ReuseRecord controls whether calls to Read may return a slice sharing + // the backing array of the previous call's returned slice for performance. + // By default, each call to Read returns newly allocated memory owned by the caller. + ReuseRecord bool + + // Deprecated: TrailingComma is no longer used. + TrailingComma bool + + r *bufio.Reader + + // numLine is the current line being read in the CSV file. + numLine int + + // offset is the input stream byte offset of the current reader position. + offset int64 + + // rawBuffer is a line buffer only used by the readLine method. + rawBuffer []byte + + // recordBuffer holds the unescaped fields, one after another. + // The fields can be accessed by using the indexes in fieldIndexes. + // E.g., For the row `a,"b","c""d",e`, recordBuffer will contain `abc"de` + // and fieldIndexes will contain the indexes [1, 2, 5, 6]. + recordBuffer []byte + + // fieldIndexes is an index of fields inside recordBuffer. + // The i'th field ends at offset fieldIndexes[i] in recordBuffer. + fieldIndexes []int + + // fieldPositions is an index of field positions for the + // last record returned by Read. 
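+	// FieldPos uses it to map a field index of the last record back to its
+	// 1-based line and column in the input.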
+ fieldPositions []position + + // lastRecord is a record cache and only used when ReuseRecord == true. + lastRecord []string +} + +// NewReader returns a new Reader that reads from r. +func NewReader(r io.Reader) *Reader { + return &Reader{ + Comma: ',', + r: bufio.NewReader(r), + } +} + +// Read reads one record (a slice of fields) from r. +// If the record has an unexpected number of fields, +// Read returns the record along with the error ErrFieldCount. +// If the record contains a field that cannot be parsed, +// Read returns a partial record along with the parse error. +// The partial record contains all fields read before the error. +// If there is no data left to be read, Read returns nil, io.EOF. +// If ReuseRecord is true, the returned slice may be shared +// between multiple calls to Read. +func (r *Reader) Read() (record []string, err error) { + if r.ReuseRecord { + record, err = r.readRecord(r.lastRecord) + r.lastRecord = record + } else { + record, err = r.readRecord(nil) + } + return record, err +} + +// FieldPos returns the line and column corresponding to +// the start of the field with the given index in the slice most recently +// returned by Read. Numbering of lines and columns starts at 1; +// columns are counted in bytes, not runes. +// +// If this is called with an out-of-bounds index, it panics. +func (r *Reader) FieldPos(field int) (line, column int) { + if field < 0 || field >= len(r.fieldPositions) { + panic("out of range index passed to FieldPos") + } + p := &r.fieldPositions[field] + return p.line, p.col +} + +// InputOffset returns the input stream byte offset of the current reader +// position. The offset gives the location of the end of the most recently +// read row and the beginning of the next row. +func (r *Reader) InputOffset() int64 { + return r.offset +} + +// pos holds the position of a field in the current line. +type position struct { + line, col int +} + +// ReadAll reads all the remaining records from r. +// Each record is a slice of fields. +// A successful call returns err == nil, not err == io.EOF. Because ReadAll is +// defined to read until EOF, it does not treat end of file as an error to be +// reported. +func (r *Reader) ReadAll() (records [][]string, err error) { + for { + record, err := r.readRecord(nil) + if err == io.EOF { + return records, nil + } + if err != nil { + return nil, err + } + records = append(records, record) + } +} + +// readLine reads the next line (with the trailing endline). +// If EOF is hit without a trailing endline, it will be omitted. +// If some bytes were read, then the error is never io.EOF. +// The result is only valid until the next call to readLine. +func (r *Reader) readLine() ([]byte, error) { + line, err := r.r.ReadSlice('\n') + if err == bufio.ErrBufferFull { + r.rawBuffer = append(r.rawBuffer[:0], line...) + for err == bufio.ErrBufferFull { + line, err = r.r.ReadSlice('\n') + r.rawBuffer = append(r.rawBuffer, line...) + } + line = r.rawBuffer + } + readSize := len(line) + if readSize > 0 && err == io.EOF { + err = nil + // For backwards compatibility, drop trailing \r before EOF. + if line[readSize-1] == '\r' { + line = line[:readSize-1] + } + } + r.numLine++ + r.offset += int64(readSize) + // Normalize \r\n to \n on all input lines. + if n := len(line); n >= 2 && line[n-2] == '\r' && line[n-1] == '\n' { + line[n-2] = '\n' + line = line[:n-1] + } + return line, err +} + +// lengthNL reports the number of bytes for the trailing \n. 
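+// readRecord relies on it to skip blank lines and to trim the newline from the
+// last field of an unquoted record.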
+func lengthNL(b []byte) int { + if len(b) > 0 && b[len(b)-1] == '\n' { + return 1 + } + return 0 +} + +// nextRune returns the next rune in b or utf8.RuneError. +func nextRune(b []byte) rune { + r, _ := utf8.DecodeRune(b) + return r +} + +func (r *Reader) readRecord(dst []string) ([]string, error) { + if r.Comma == r.Comment || !validDelim(r.Comma) || (r.Comment != 0 && !validDelim(r.Comment)) { + return nil, errInvalidDelim + } + + // Read line (automatically skipping past empty lines and any comments). + var line []byte + var errRead error + for errRead == nil { + line, errRead = r.readLine() + if r.Comment != 0 && nextRune(line) == r.Comment { + line = nil + continue // Skip comment lines + } + if errRead == nil && len(line) == lengthNL(line) { + line = nil + continue // Skip empty lines + } + break + } + if errRead == io.EOF { + return nil, errRead + } + + // Parse each field in the record. + var err error + const quoteLen = len(`"`) + commaLen := utf8.RuneLen(r.Comma) + recLine := r.numLine // Starting line for record + r.recordBuffer = r.recordBuffer[:0] + r.fieldIndexes = r.fieldIndexes[:0] + r.fieldPositions = r.fieldPositions[:0] + pos := position{line: r.numLine, col: 1} +parseField: + for { + if r.TrimLeadingSpace { + i := bytes.IndexFunc(line, func(r rune) bool { + return !unicode.IsSpace(r) + }) + if i < 0 { + i = len(line) + pos.col -= lengthNL(line) + } + line = line[i:] + pos.col += i + } + if len(line) == 0 || line[0] != '"' { + // Non-quoted string field + i := bytes.IndexRune(line, r.Comma) + field := line + if i >= 0 { + field = field[:i] + } else { + field = field[:len(field)-lengthNL(field)] + } + // Check to make sure a quote does not appear in field. + if !r.LazyQuotes { + if j := bytes.IndexByte(field, '"'); j >= 0 { + col := pos.col + j + err = &ParseError{StartLine: recLine, Line: r.numLine, Column: col, Err: ErrBareQuote} + break parseField + } + } + r.recordBuffer = append(r.recordBuffer, field...) + r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer)) + r.fieldPositions = append(r.fieldPositions, pos) + if i >= 0 { + line = line[i+commaLen:] + pos.col += i + commaLen + continue parseField + } + break parseField + } else { + // Quoted string field + fieldPos := pos + line = line[quoteLen:] + pos.col += quoteLen + for { + i := bytes.IndexByte(line, '"') + if i >= 0 { + // Hit next quote. + r.recordBuffer = append(r.recordBuffer, line[:i]...) + line = line[i+quoteLen:] + pos.col += i + quoteLen + switch rn := nextRune(line); { + case rn == '"': + // `""` sequence (append quote). + r.recordBuffer = append(r.recordBuffer, '"') + line = line[quoteLen:] + pos.col += quoteLen + case rn == r.Comma: + // `",` sequence (end of field). + line = line[commaLen:] + pos.col += commaLen + r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer)) + r.fieldPositions = append(r.fieldPositions, fieldPos) + continue parseField + case lengthNL(line) == len(line): + // `"\n` sequence (end of line). + r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer)) + r.fieldPositions = append(r.fieldPositions, fieldPos) + break parseField + case r.LazyQuotes: + // `"` sequence (bare quote). + r.recordBuffer = append(r.recordBuffer, '"') + default: + // `"*` sequence (invalid non-escaped quote). + err = &ParseError{StartLine: recLine, Line: r.numLine, Column: pos.col - quoteLen, Err: ErrQuote} + break parseField + } + } else if len(line) > 0 { + // Hit end of line (copy all data so far). + r.recordBuffer = append(r.recordBuffer, line...) 
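+				// The closing quote was not on this line: a quoted field may span multiple
+				// physical lines, so keep the bytes read so far and fetch the next line.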
+ if errRead != nil { + break parseField + } + pos.col += len(line) + line, errRead = r.readLine() + if len(line) > 0 { + pos.line++ + pos.col = 1 + } + if errRead == io.EOF { + errRead = nil + } + } else { + // Abrupt end of file (EOF or error). + if !r.LazyQuotes && errRead == nil { + err = &ParseError{StartLine: recLine, Line: pos.line, Column: pos.col, Err: ErrQuote} + break parseField + } + r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer)) + r.fieldPositions = append(r.fieldPositions, fieldPos) + break parseField + } + } + } + } + if err == nil { + err = errRead + } + + // Create a single string and create slices out of it. + // This pins the memory of the fields together, but allocates once. + str := string(r.recordBuffer) // Convert to string once to batch allocations + dst = dst[:0] + if cap(dst) < len(r.fieldIndexes) { + dst = make([]string, len(r.fieldIndexes)) + } + dst = dst[:len(r.fieldIndexes)] + var preIdx int + for i, idx := range r.fieldIndexes { + dst[i] = str[preIdx:idx] + preIdx = idx + } + + // Check or update the expected fields per record. + if r.FieldsPerRecord > 0 { + if len(dst) != r.FieldsPerRecord && err == nil { + err = &ParseError{ + StartLine: recLine, + Line: recLine, + Column: 1, + Err: ErrFieldCount, + } + } + } else if r.FieldsPerRecord == 0 { + r.FieldsPerRecord = len(dst) + } + return dst, err +} diff --git a/internal/util/importutil/import_util.go b/internal/util/importutil/import_util.go index bed5d8aaec..5f8231b616 100644 --- a/internal/util/importutil/import_util.go +++ b/internal/util/importutil/import_util.go @@ -464,7 +464,7 @@ func fillDynamicData(blockData BlockData, collectionSchema *schemapb.CollectionS // tryFlushBlocks does the two things: // 1. if accumulate data of a block exceed blockSize, call callFlushFunc to generate new binlog file -// 2. if total accumulate data exceed maxTotalSize, call callFlushFUnc to flush the biggest block +// 2. if total accumulate data exceed maxTotalSize, call callFlushFunc to flush the biggest block func tryFlushBlocks(ctx context.Context, shardsData []ShardData, collectionSchema *schemapb.CollectionSchema, diff --git a/internal/util/importutil/import_wrapper.go b/internal/util/importutil/import_wrapper.go index dc041312a1..d136d164fc 100644 --- a/internal/util/importutil/import_wrapper.go +++ b/internal/util/importutil/import_wrapper.go @@ -38,6 +38,7 @@ import ( const ( JSONFileExt = ".json" NumpyFileExt = ".npy" + CSVFileExt = ".csv" // parsers read JSON/Numpy/CSV files buffer by buffer, this limitation is to define the buffer size. 
ReadBufferSize = 16 * 1024 * 1024 // 16MB @@ -186,21 +187,21 @@ func (p *ImportWrapper) fileValidation(filePaths []string) (bool, error) { filePath := filePaths[i] name, fileType := GetFileNameAndExt(filePath) - // only allow json file or numpy file - if fileType != JSONFileExt && fileType != NumpyFileExt { + // only allow json file, numpy file and csv file + if fileType != JSONFileExt && fileType != NumpyFileExt && fileType != CSVFileExt { log.Warn("import wrapper: unsupported file type", zap.String("filePath", filePath)) return false, fmt.Errorf("unsupported file type: '%s'", filePath) } // we use the first file to determine row-based or column-based - if i == 0 && fileType == JSONFileExt { + if i == 0 && (fileType == JSONFileExt || fileType == CSVFileExt) { rowBased = true } // check file type - // row-based only support json type, column-based only support numpy type + // row-based only support json and csv type, column-based only support numpy type if rowBased { - if fileType != JSONFileExt { + if fileType != JSONFileExt && fileType != CSVFileExt { log.Warn("import wrapper: unsupported file type for row-based mode", zap.String("filePath", filePath)) return rowBased, fmt.Errorf("unsupported file type for row-based mode: '%s'", filePath) } @@ -278,6 +279,12 @@ func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error log.Warn("import wrapper: failed to parse row-based json file", zap.Error(err), zap.String("filePath", filePath)) return err } + } else if fileType == CSVFileExt { + err = p.parseRowBasedCSV(filePath, options.OnlyValidate) + if err != nil { + log.Warn("import wrapper: failed to parse row-based csv file", zap.Error(err), zap.String("filePath", filePath)) + return err + } } // no need to check else, since the fileValidation() already do this // trigger gc after each file finished @@ -459,6 +466,54 @@ func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) er return nil } +func (p *ImportWrapper) parseRowBasedCSV(filePath string, onlyValidate bool) error { + tr := timerecord.NewTimeRecorder("csv row-based parser: " + filePath) + + file, err := p.chunkManager.Reader(p.ctx, filePath) + if err != nil { + return err + } + defer file.Close() + size, err := p.chunkManager.Size(p.ctx, filePath) + if err != nil { + return err + } + // csv parser + reader := bufio.NewReader(file) + parser, err := NewCSVParser(p.ctx, p.collectionInfo, p.updateProgressPercent) + if err != nil { + return err + } + + // if only validate, we input a empty flushFunc so that the consumer do nothing but only validation. + var flushFunc ImportFlushFunc + if onlyValidate { + flushFunc = func(fields BlockData, shardID int, partitionID int64) error { + return nil + } + } else { + flushFunc = func(fields BlockData, shardID int, partitionID int64) error { + filePaths := []string{filePath} + printFieldsDataInfo(fields, "import wrapper: prepare to flush binlogs", filePaths) + return p.flushFunc(fields, shardID, partitionID) + } + } + + consumer, err := NewCSVRowConsumer(p.ctx, p.collectionInfo, p.rowIDAllocator, p.binlogSize, flushFunc) + if err != nil { + return err + } + + err = parser.ParseRows(&IOReader{r: reader, fileSize: size}, consumer) + if err != nil { + return err + } + p.importResult.AutoIds = append(p.importResult.AutoIds, consumer.IDRange()...) 
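+	// consumer.IDRange() holds the ranges of auto-generated primary keys produced while
+	// consuming rows; they are appended to the import result so the caller can report them.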
+ + tr.Elapse("parsed") + return nil +} + // flushFunc is the callback function for parsers generate segment and save binlog files func (p *ImportWrapper) flushFunc(fields BlockData, shardID int, partitionID int64) error { logFields := []zap.Field{ diff --git a/internal/util/importutil/import_wrapper_test.go b/internal/util/importutil/import_wrapper_test.go index 10ad8409a6..8f08cd1437 100644 --- a/internal/util/importutil/import_wrapper_test.go +++ b/internal/util/importutil/import_wrapper_test.go @@ -326,6 +326,93 @@ func Test_ImportWrapperRowBased(t *testing.T) { }) } +func Test_ImportWrapperRowBased_CSV(t *testing.T) { + err := os.MkdirAll(TempFilesPath, os.ModePerm) + assert.NoError(t, err) + defer os.RemoveAll(TempFilesPath) + paramtable.Init() + + // NewDefaultFactory() use "/tmp/milvus" as default root path, and cannot specify root path + // NewChunkManagerFactory() can specify the root path + f := storage.NewChunkManagerFactory("local", storage.RootPath(TempFilesPath)) + ctx := context.Background() + cm, err := f.NewPersistentStorageChunkManager(ctx) + assert.NoError(t, err) + defer cm.RemoveWithPrefix(ctx, cm.RootPath()) + + idAllocator := newIDAllocator(ctx, t, nil) + content := []byte( + `FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector + true,10,101,1001,10001,3.14,1.56,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]" + false,11,102,1002,10002,3.15,1.57,No.1,"{""x"": 1}","[201,0]","[0.1,0.2,0.3,0.4]" + true,12,103,1003,10003,3.16,1.58,No.2,"{""x"": 2}","[202,0]","[0.1,0.2,0.3,0.4]"`) + + filePath := TempFilesPath + "rows_1.csv" + err = cm.Write(ctx, filePath, content) + assert.NoError(t, err) + rowCounter := &rowCounterTest{} + assignSegmentFunc, flushFunc, saveSegmentFunc := createMockCallbackFunctions(t, rowCounter) + importResult := &rootcoordpb.ImportResult{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + }, + TaskId: 1, + DatanodeId: 1, + State: commonpb.ImportState_ImportStarted, + Segments: make([]int64, 0), + AutoIds: make([]int64, 0), + RowCount: 0, + } + + reportFunc := func(res *rootcoordpb.ImportResult) error { + return nil + } + collectionInfo, err := NewCollectionInfo(sampleSchema(), 2, []int64{1}) + assert.NoError(t, err) + + t.Run("success case", func(t *testing.T) { + wrapper := NewImportWrapper(ctx, collectionInfo, 1, ReadBufferSize, idAllocator, cm, importResult, reportFunc) + wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc) + files := make([]string, 0) + files = append(files, filePath) + err = wrapper.Import(files, ImportOptions{OnlyValidate: true}) + assert.NoError(t, err) + assert.Equal(t, 0, rowCounter.rowCount) + + err = wrapper.Import(files, DefaultImportOptions()) + assert.NoError(t, err) + assert.Equal(t, 3, rowCounter.rowCount) + assert.Equal(t, commonpb.ImportState_ImportPersisted, importResult.State) + }) + + t.Run("parse error", func(t *testing.T) { + content := []byte( + `FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector + true,false,103,1003,10003,3.16,1.58,No.2,"{""x"": 2}","[202,0]","[0.1,0.2,0.3,0.4]"`) + + filePath = TempFilesPath + "rows_2.csv" + err = cm.Write(ctx, filePath, content) + assert.NoError(t, err) + + importResult.State = commonpb.ImportState_ImportStarted + wrapper := NewImportWrapper(ctx, collectionInfo, 1, ReadBufferSize, idAllocator, cm, importResult, reportFunc) + wrapper.SetCallbackFunctions(assignSegmentFunc, 
flushFunc, saveSegmentFunc) + files := make([]string, 0) + files = append(files, filePath) + err = wrapper.Import(files, ImportOptions{OnlyValidate: true}) + assert.Error(t, err) + assert.NotEqual(t, commonpb.ImportState_ImportPersisted, importResult.State) + }) + + t.Run("file doesn't exist", func(t *testing.T) { + files := make([]string, 0) + files = append(files, "/dummy/dummy.csv") + wrapper := NewImportWrapper(ctx, collectionInfo, 1, ReadBufferSize, idAllocator, cm, importResult, reportFunc) + err = wrapper.Import(files, ImportOptions{OnlyValidate: true}) + assert.Error(t, err) + }) +} + func Test_ImportWrapperColumnBased_numpy(t *testing.T) { err := os.MkdirAll(TempFilesPath, os.ModePerm) assert.NoError(t, err) diff --git a/internal/util/importutil/json_parser.go b/internal/util/importutil/json_parser.go index 97d30254ca..06df18a023 100644 --- a/internal/util/importutil/json_parser.go +++ b/internal/util/importutil/json_parser.go @@ -110,7 +110,9 @@ func (p *JSONParser) combineDynamicRow(dynamicValues map[string]interface{}, row if value, is := obj.(string); is { // case 1 mp := make(map[string]interface{}) - err := json.Unmarshal([]byte(value), &mp) + desc := json.NewDecoder(strings.NewReader(value)) + desc.UseNumber() + err := desc.Decode(&mp) if err != nil { // invalid input return errors.New("illegal value for dynamic field, not a JSON format string") @@ -192,7 +194,7 @@ func (p *JSONParser) verifyRow(raw interface{}) (map[storage.FieldID]interface{} } } - // combine the redundant pairs into dunamic field(if has) + // combine the redundant pairs into dynamic field(if has) err := p.combineDynamicRow(dynamicValues, row) if err != nil { log.Warn("JSON parser: failed to combine dynamic values", zap.Error(err))
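
The json_parser.go hunk above replaces json.Unmarshal with a json.Decoder configured with UseNumber when a dynamic-field value arrives as a JSON-format string. A minimal standalone sketch of why that matters (not part of the patch; the field name "id" is only illustrative): plain Unmarshal decodes every JSON number as float64 and silently rounds integers above 2^53, while UseNumber keeps each literal as a json.Number and preserves the digits.

package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

func main() {
	// 9007199254740993 is 2^53+1 and cannot be represented exactly as a float64.
	value := `{"id": 9007199254740993}`

	// json.Unmarshal decodes JSON numbers as float64, so the value is rounded.
	var viaUnmarshal map[string]interface{}
	_ = json.Unmarshal([]byte(value), &viaUnmarshal)
	fmt.Printf("%T %v\n", viaUnmarshal["id"], viaUnmarshal["id"]) // float64 9.007199254740992e+15

	// A Decoder with UseNumber keeps the literal as json.Number, preserving every digit.
	viaDecoder := make(map[string]interface{})
	dec := json.NewDecoder(strings.NewReader(value))
	dec.UseNumber()
	_ = dec.Decode(&viaDecoder)
	fmt.Printf("%T %v\n", viaDecoder["id"], viaDecoder["id"]) // json.Number 9007199254740993
}

Under this assumption, an int64 primary key or large integer embedded in a dynamic-field string survives the round trip through combineDynamicRow instead of being degraded to a float64 approximation.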