diff --git a/internal/rootcoord/import_manager.go b/internal/rootcoord/import_manager.go
index 0e2c64f1da..556f15c90d 100644
--- a/internal/rootcoord/import_manager.go
+++ b/internal/rootcoord/import_manager.go
@@ -399,10 +399,10 @@ func (m *importManager) isRowbased(files []string) (bool, error) {
 	isRowBased := false
 	for _, filePath := range files {
 		_, fileType := importutil.GetFileNameAndExt(filePath)
-		if fileType == importutil.JSONFileExt || fileType == importutil.CSVFileExt {
+		if fileType == importutil.JSONFileExt {
 			isRowBased = true
 		} else if isRowBased {
-			log.Error("row-based data file type must be JSON or CSV, mixed file types is not allowed", zap.Strings("files", files))
+			log.Error("row-based data file type must be JSON, mixed file types is not allowed", zap.Strings("files", files))
 			return isRowBased, fmt.Errorf("row-based data file type must be JSON or CSV, file type '%s' is not allowed", fileType)
 		}
 	}
diff --git a/internal/rootcoord/import_manager_test.go b/internal/rootcoord/import_manager_test.go
index 04a6ea0d2e..fe94b2db66 100644
--- a/internal/rootcoord/import_manager_test.go
+++ b/internal/rootcoord/import_manager_test.go
@@ -1101,26 +1101,6 @@ func TestImportManager_isRowbased(t *testing.T) {
 	rb, err = mgr.isRowbased(files)
 	assert.NoError(t, err)
 	assert.False(t, rb)
-
-	files = []string{"1.csv"}
-	rb, err = mgr.isRowbased(files)
-	assert.NoError(t, err)
-	assert.True(t, rb)
-
-	files = []string{"1.csv", "2.csv"}
-	rb, err = mgr.isRowbased(files)
-	assert.Error(t, err)
-	assert.True(t, rb)
-
-	files = []string{"1.csv", "2.json"}
-	rb, err = mgr.isRowbased(files)
-	assert.Error(t, err)
-	assert.True(t, rb)
-
-	files = []string{"1.csv", "2.npy"}
-	rb, err = mgr.isRowbased(files)
-	assert.Error(t, err)
-	assert.True(t, rb)
 }
 
 func TestImportManager_mergeArray(t *testing.T) {
diff --git a/internal/util/importutil/csv_handler.go b/internal/util/importutil/csv_handler.go
deleted file mode 100644
index b518f4ab26..0000000000
--- a/internal/util/importutil/csv_handler.go
+++ /dev/null
@@ -1,458 +0,0 @@
-// Licensed to the LF AI & Data foundation under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -package importutil - -import ( - "context" - "encoding/json" - "fmt" - "strconv" - "strings" - - "go.uber.org/zap" - - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/allocator" - "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/pkg/common" - "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus/pkg/util/typeutil" -) - -type CSVRowHandler interface { - Handle(row []map[storage.FieldID]string) error -} - -// CSVRowConsumer is row-based csv format consumer class -type CSVRowConsumer struct { - ctx context.Context // for canceling parse process - collectionInfo *CollectionInfo // collection details including schema - rowIDAllocator *allocator.IDAllocator // autoid allocator - validators map[storage.FieldID]*CSVValidator // validators for each field - rowCounter int64 // how many rows have been consumed - shardsData []ShardData // in-memory shards data - blockSize int64 // maximum size of a read block(unit:byte) - autoIDRange []int64 // auto-generated id range, for example: [1, 10, 20, 25] means id from 1 to 10 and 20 to 25 - - callFlushFunc ImportFlushFunc // call back function to flush segment -} - -func NewCSVRowConsumer(ctx context.Context, - collectionInfo *CollectionInfo, - idAlloc *allocator.IDAllocator, - blockSize int64, - flushFunc ImportFlushFunc, -) (*CSVRowConsumer, error) { - if collectionInfo == nil { - log.Warn("CSV row consumer: collection schema is nil") - return nil, merr.WrapErrImportFailed("collection schema is nil") - } - - v := &CSVRowConsumer{ - ctx: ctx, - collectionInfo: collectionInfo, - rowIDAllocator: idAlloc, - validators: make(map[storage.FieldID]*CSVValidator), - rowCounter: 0, - shardsData: make([]ShardData, 0, collectionInfo.ShardNum), - blockSize: blockSize, - autoIDRange: make([]int64, 0), - callFlushFunc: flushFunc, - } - - if err := v.initValidators(collectionInfo.Schema); err != nil { - log.Warn("CSV row consumer: fail to initialize csv row-based consumer", zap.Error(err)) - return nil, merr.WrapErrImportFailed(fmt.Sprintf("fail to initialize csv row-based consumer, error: %v", err)) - } - - for i := 0; i < int(collectionInfo.ShardNum); i++ { - shardData := initShardData(collectionInfo.Schema, collectionInfo.PartitionIDs) - if shardData == nil { - log.Warn("CSV row consumer: fail to initialize in-memory segment data", zap.Int("shardID", i)) - return nil, merr.WrapErrImportFailed(fmt.Sprintf("fail to initialize in-memory segment data for shard id %d", i)) - } - v.shardsData = append(v.shardsData, shardData) - } - - // primary key is autoid, id generator is required - if v.collectionInfo.PrimaryKey.GetAutoID() && idAlloc == nil { - log.Warn("CSV row consumer: ID allocator is nil") - return nil, merr.WrapErrImportFailed("ID allocator is nil") - } - - return v, nil -} - -type CSVValidator struct { - convertFunc func(val string, field storage.FieldData) error // convert data function - isString bool // for string field - fieldName string // field name -} - -func (v *CSVRowConsumer) initValidators(collectionSchema *schemapb.CollectionSchema) error { - if collectionSchema == nil { - return merr.WrapErrImportFailed("collection schema is nil") - } - - validators := v.validators - - for i := 0; i < len(collectionSchema.Fields); i++ { - schema := collectionSchema.Fields[i] - - validators[schema.GetFieldID()] = &CSVValidator{} - validators[schema.GetFieldID()].fieldName = schema.GetName() - 
validators[schema.GetFieldID()].isString = false - - switch schema.DataType { - // all obj is string type - case schemapb.DataType_Bool: - validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { - var value bool - if err := json.Unmarshal([]byte(str), &value); err != nil { - return merr.WrapErrImportFailed(fmt.Sprintf("illegal value '%v' for bool type field '%s'", str, schema.GetName())) - } - field.(*storage.BoolFieldData).Data = append(field.(*storage.BoolFieldData).Data, value) - return nil - } - case schemapb.DataType_Float: - validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { - value, err := parseFloat(str, 32, schema.GetName()) - if err != nil { - return err - } - field.(*storage.FloatFieldData).Data = append(field.(*storage.FloatFieldData).Data, float32(value)) - return nil - } - case schemapb.DataType_Double: - validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { - value, err := parseFloat(str, 64, schema.GetName()) - if err != nil { - return err - } - field.(*storage.DoubleFieldData).Data = append(field.(*storage.DoubleFieldData).Data, value) - return nil - } - case schemapb.DataType_Int8: - validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { - value, err := strconv.ParseInt(str, 0, 8) - if err != nil { - return merr.WrapErrImportFailed(fmt.Sprintf("failed to parse value '%v' for int8 field '%s', error: %v", str, schema.GetName(), err)) - } - field.(*storage.Int8FieldData).Data = append(field.(*storage.Int8FieldData).Data, int8(value)) - return nil - } - case schemapb.DataType_Int16: - validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { - value, err := strconv.ParseInt(str, 0, 16) - if err != nil { - return merr.WrapErrImportFailed(fmt.Sprintf("failed to parse value '%v' for int16 field '%s', error: %v", str, schema.GetName(), err)) - } - field.(*storage.Int16FieldData).Data = append(field.(*storage.Int16FieldData).Data, int16(value)) - return nil - } - case schemapb.DataType_Int32: - validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { - value, err := strconv.ParseInt(str, 0, 32) - if err != nil { - return merr.WrapErrImportFailed(fmt.Sprintf("failed to parse value '%v' for int32 field '%s', error: %v", str, schema.GetName(), err)) - } - field.(*storage.Int32FieldData).Data = append(field.(*storage.Int32FieldData).Data, int32(value)) - return nil - } - case schemapb.DataType_Int64: - validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { - value, err := strconv.ParseInt(str, 0, 64) - if err != nil { - return merr.WrapErrImportFailed(fmt.Sprintf("failed to parse value '%v' for int64 field '%s', error: %v", str, schema.GetName(), err)) - } - field.(*storage.Int64FieldData).Data = append(field.(*storage.Int64FieldData).Data, value) - return nil - } - case schemapb.DataType_BinaryVector: - dim, err := getFieldDimension(schema) - if err != nil { - return err - } - - validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { - var arr []interface{} - desc := json.NewDecoder(strings.NewReader(str)) - desc.UseNumber() - if err := desc.Decode(&arr); err != nil { - return merr.WrapErrImportFailed(fmt.Sprintf("'%v' is not an array for binary vector field '%s'", str, schema.GetName())) - } - - // we use uint8 to represent binary vector in csv file, each uint8 
value represents 8 dimensions. - if len(arr)*8 != dim { - return merr.WrapErrImportFailed(fmt.Sprintf("bit size %d doesn't equal to vector dimension %d of field '%s'", len(arr)*8, dim, schema.GetName())) - } - - for i := 0; i < len(arr); i++ { - if num, ok := arr[i].(json.Number); ok { - value, err := strconv.ParseUint(string(num), 0, 8) - if err != nil { - return merr.WrapErrImportFailed(fmt.Sprintf("failed to parse value '%v' for binary vector field '%s', error: %v", num, schema.GetName(), err)) - } - field.(*storage.BinaryVectorFieldData).Data = append(field.(*storage.BinaryVectorFieldData).Data, byte(value)) - } else { - return merr.WrapErrImportFailed(fmt.Sprintf("illegal value '%v' for binary vector field '%s'", str, schema.GetName())) - } - } - - return nil - } - case schemapb.DataType_FloatVector: - dim, err := getFieldDimension(schema) - if err != nil { - return err - } - - validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { - var arr []interface{} - desc := json.NewDecoder(strings.NewReader(str)) - desc.UseNumber() - if err := desc.Decode(&arr); err != nil { - return merr.WrapErrImportFailed(fmt.Sprintf("'%v' is not an array for float vector field '%s'", str, schema.GetName())) - } - - if len(arr) != dim { - return merr.WrapErrImportFailed(fmt.Sprintf("array size %d doesn't equal to vector dimension %d of field '%s'", len(arr), dim, schema.GetName())) - } - - for i := 0; i < len(arr); i++ { - if num, ok := arr[i].(json.Number); ok { - value, err := parseFloat(string(num), 32, schema.GetName()) - if err != nil { - return err - } - field.(*storage.FloatVectorFieldData).Data = append(field.(*storage.FloatVectorFieldData).Data, float32(value)) - } else { - return merr.WrapErrImportFailed(fmt.Sprintf("illegal value '%v' for float vector field '%s'", str, schema.GetName())) - } - } - - return nil - } - case schemapb.DataType_String, schemapb.DataType_VarChar: - validators[schema.GetFieldID()].isString = true - - validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { - field.(*storage.StringFieldData).Data = append(field.(*storage.StringFieldData).Data, str) - return nil - } - case schemapb.DataType_JSON: - validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { - var dummy interface{} - if err := json.Unmarshal([]byte(str), &dummy); err != nil { - return merr.WrapErrImportFailed(fmt.Sprintf("failed to parse value '%v' for JSON field '%s', error: %v", str, schema.GetName(), err)) - } - field.(*storage.JSONFieldData).Data = append(field.(*storage.JSONFieldData).Data, []byte(str)) - return nil - } - case schemapb.DataType_Array: - validators[schema.GetFieldID()].convertFunc = func(str string, field storage.FieldData) error { - var arr []interface{} - desc := json.NewDecoder(strings.NewReader(str)) - desc.UseNumber() - if err := desc.Decode(&arr); err != nil { - return merr.WrapErrImportFailed(fmt.Sprintf("'%v' is not an array for array field '%s'", str, schema.GetName())) - } - - return getArrayElementData(schema, arr, field) - } - - default: - return merr.WrapErrImportFailed(fmt.Sprintf("unsupport data type: %s", getTypeName(collectionSchema.Fields[i].DataType))) - } - } - return nil -} - -func (v *CSVRowConsumer) IDRange() []int64 { - return v.autoIDRange -} - -func (v *CSVRowConsumer) RowCount() int64 { - return v.rowCounter -} - -func (v *CSVRowConsumer) Handle(rows []map[storage.FieldID]string) error { - if v == nil || v.validators == nil || len(v.validators) 
== 0 { - log.Warn("CSV row consumer is not initialized") - return merr.WrapErrImportFailed("CSV row consumer is not initialized") - } - // if rows is nil, that means read to end of file, force flush all data - if rows == nil { - err := tryFlushBlocks(v.ctx, v.shardsData, v.collectionInfo.Schema, v.callFlushFunc, v.blockSize, MaxTotalSizeInMemory, true) - log.Info("CSV row consumer finished") - return err - } - - // rows is not nil, flush in necessary: - // 1. data block size larger than v.blockSize will be flushed - // 2. total data size exceeds MaxTotalSizeInMemory, the largest data block will be flushed - err := tryFlushBlocks(v.ctx, v.shardsData, v.collectionInfo.Schema, v.callFlushFunc, v.blockSize, MaxTotalSizeInMemory, false) - if err != nil { - log.Warn("CSV row consumer: try flush data but failed", zap.Error(err)) - return merr.WrapErrImportFailed(fmt.Sprintf("try flush data but failed, error: %v", err)) - } - - // prepare autoid, no matter int64 or varchar pk, we always generate autoid since the hidden field RowIDField requires them - primaryKeyID := v.collectionInfo.PrimaryKey.FieldID - primaryValidator := v.validators[primaryKeyID] - var rowIDBegin typeutil.UniqueID - var rowIDEnd typeutil.UniqueID - if v.collectionInfo.PrimaryKey.AutoID { - if v.rowIDAllocator == nil { - log.Warn("CSV row consumer: primary keys is auto-generated but IDAllocator is nil") - return merr.WrapErrImportFailed("primary keys is auto-generated but IDAllocator is nil") - } - var err error - rowIDBegin, rowIDEnd, err = v.rowIDAllocator.Alloc(uint32(len(rows))) - if err != nil { - log.Warn("CSV row consumer: failed to generate primary keys", zap.Int("count", len(rows)), zap.Error(err)) - return merr.WrapErrImportFailed(fmt.Sprintf("failed to generate %d primary keys, error: %v", len(rows), err)) - } - if rowIDEnd-rowIDBegin != int64(len(rows)) { - log.Warn("CSV row consumer: try to generate primary keys but allocated ids are not enough", - zap.Int("count", len(rows)), zap.Int64("generated", rowIDEnd-rowIDBegin)) - return merr.WrapErrImportFailed(fmt.Sprintf("try to generate %d primary keys but only %d keys were allocated", len(rows), rowIDEnd-rowIDBegin)) - } - log.Info("CSV row consumer: auto-generate primary keys", zap.Int64("begin", rowIDBegin), zap.Int64("end", rowIDEnd)) - if primaryValidator.isString { - // if pk is varchar, no need to record auto-generated row ids - log.Warn("CSV row consumer: string type primary key connot be auto-generated") - return merr.WrapErrImportFailed("string type primary key connot be auto-generated") - } - v.autoIDRange = append(v.autoIDRange, rowIDBegin, rowIDEnd) - } - - // consume rows - for i := 0; i < len(rows); i++ { - row := rows[i] - rowNumber := v.rowCounter + int64(i) - - // hash to a shard number - var shardID uint32 - var partitionID int64 - if primaryValidator.isString { - pk := row[primaryKeyID] - - // hash to shard based on pk, hash to partition if partition key exist - hash := typeutil.HashString2Uint32(pk) - shardID = hash % uint32(v.collectionInfo.ShardNum) - partitionID, err = v.hashToPartition(row, rowNumber) - if err != nil { - return err - } - - pkArray := v.shardsData[shardID][partitionID][primaryKeyID].(*storage.StringFieldData) - pkArray.Data = append(pkArray.Data, pk) - } else { - var pk int64 - if v.collectionInfo.PrimaryKey.AutoID { - pk = rowIDBegin + int64(i) - } else { - pkStr := row[primaryKeyID] - pk, err = strconv.ParseInt(pkStr, 10, 64) - if err != nil { - log.Warn("CSV row consumer: failed to parse primary key at the row", - 
zap.String("value", pkStr), zap.Int64("rowNumber", rowNumber), zap.Error(err)) - return merr.WrapErrImportFailed(fmt.Sprintf("failed to parse primary key '%s' at the row %d, error: %v", - pkStr, rowNumber, err)) - } - } - - hash, err := typeutil.Hash32Int64(pk) - if err != nil { - log.Warn("CSV row consumer: failed to hash primary key at the row", - zap.Int64("key", pk), zap.Int64("rowNumber", rowNumber), zap.Error(err)) - return merr.WrapErrImportFailed(fmt.Sprintf("failed to hash primary key %d at the row %d, error: %v", pk, rowNumber, err)) - } - - // hash to shard based on pk, hash to partition if partition key exist - shardID = hash % uint32(v.collectionInfo.ShardNum) - partitionID, err = v.hashToPartition(row, rowNumber) - if err != nil { - return err - } - - pkArray := v.shardsData[shardID][partitionID][primaryKeyID].(*storage.Int64FieldData) - pkArray.Data = append(pkArray.Data, pk) - } - rowIDField := v.shardsData[shardID][partitionID][common.RowIDField].(*storage.Int64FieldData) - rowIDField.Data = append(rowIDField.Data, rowIDBegin+int64(i)) - - for fieldID, validator := range v.validators { - if fieldID == v.collectionInfo.PrimaryKey.GetFieldID() { - continue - } - - value := row[fieldID] - if err := validator.convertFunc(value, v.shardsData[shardID][partitionID][fieldID]); err != nil { - log.Warn("CSV row consumer: failed to convert value for field at the row", - zap.String("fieldName", validator.fieldName), zap.Int64("rowNumber", rowNumber), zap.Error(err)) - return merr.WrapErrImportFailed(fmt.Sprintf("failed to convert value for field '%s' at the row %d, error: %v", - validator.fieldName, rowNumber, err)) - } - } - } - - v.rowCounter += int64(len(rows)) - return nil -} - -// hashToPartition hash partition key to get a partition ID, return the first partition ID if no partition key exist -// CollectionInfo ensures only one partition ID in the PartitionIDs if no partition key exist -func (v *CSVRowConsumer) hashToPartition(row map[storage.FieldID]string, rowNumber int64) (int64, error) { - if v.collectionInfo.PartitionKey == nil { - if len(v.collectionInfo.PartitionIDs) != 1 { - return 0, merr.WrapErrImportFailed(fmt.Sprintf("collection '%s' partition list is empty", v.collectionInfo.Schema.Name)) - } - // no partition key, directly return the target partition id - return v.collectionInfo.PartitionIDs[0], nil - } - - partitionKeyID := v.collectionInfo.PartitionKey.GetFieldID() - partitionKeyValidator := v.validators[partitionKeyID] - value := row[partitionKeyID] - - var hashValue uint32 - if partitionKeyValidator.isString { - hashValue = typeutil.HashString2Uint32(value) - } else { - // parse the value from a string - pk, err := strconv.ParseInt(value, 10, 64) - if err != nil { - log.Warn("CSV row consumer: failed to parse partition key at the row", - zap.String("value", value), zap.Int64("rowNumber", rowNumber), zap.Error(err)) - return 0, merr.WrapErrImportFailed(fmt.Sprintf("failed to parse partition key '%s' at the row %d, error: %v", - value, rowNumber, err)) - } - - hashValue, err = typeutil.Hash32Int64(pk) - if err != nil { - log.Warn("CSV row consumer: failed to hash partition key at the row", - zap.Int64("key", pk), zap.Int64("rowNumber", rowNumber), zap.Error(err)) - return 0, merr.WrapErrImportFailed(fmt.Sprintf("failed to hash partition key %d at the row %d, error: %v", pk, rowNumber, err)) - } - } - - index := int64(hashValue % uint32(len(v.collectionInfo.PartitionIDs))) - return v.collectionInfo.PartitionIDs[index], nil -} diff --git 
a/internal/util/importutil/csv_handler_test.go b/internal/util/importutil/csv_handler_test.go deleted file mode 100644 index 9590a30a35..0000000000 --- a/internal/util/importutil/csv_handler_test.go +++ /dev/null @@ -1,926 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package importutil - -import ( - "context" - "strconv" - "testing" - - "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/pkg/common" -) - -func Test_CSVRowConsumerNew(t *testing.T) { - ctx := context.Background() - - t.Run("nil schema", func(t *testing.T) { - consumer, err := NewCSVRowConsumer(ctx, nil, nil, 16, nil) - assert.Error(t, err) - assert.Nil(t, consumer) - }) - - t.Run("wrong schema", func(t *testing.T) { - schema := &schemapb.CollectionSchema{ - Name: "schema", - AutoID: true, - Fields: []*schemapb.FieldSchema{ - { - FieldID: 101, - Name: "uid", - IsPrimaryKey: true, - AutoID: false, - DataType: schemapb.DataType_Int64, - }, - }, - } - collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1}) - assert.NoError(t, err) - - schema.Fields[0].DataType = schemapb.DataType_None - consumer, err := NewCSVRowConsumer(ctx, collectionInfo, nil, 16, nil) - assert.Error(t, err) - assert.Nil(t, consumer) - }) - - t.Run("primary key is autoid but no IDAllocator", func(t *testing.T) { - schema := &schemapb.CollectionSchema{ - Name: "schema", - AutoID: true, - Fields: []*schemapb.FieldSchema{ - { - FieldID: 101, - Name: "uid", - IsPrimaryKey: true, - AutoID: true, - DataType: schemapb.DataType_Int64, - }, - }, - } - collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1}) - assert.NoError(t, err) - - consumer, err := NewCSVRowConsumer(ctx, collectionInfo, nil, 16, nil) - assert.Error(t, err) - assert.Nil(t, consumer) - }) - - t.Run("succeed", func(t *testing.T) { - collectionInfo, err := NewCollectionInfo(sampleSchema(), 2, []int64{1}) - assert.NoError(t, err) - - consumer, err := NewCSVRowConsumer(ctx, collectionInfo, nil, 16, nil) - assert.NoError(t, err) - assert.NotNil(t, consumer) - }) -} - -func Test_CSVRowConsumerInitValidators(t *testing.T) { - ctx := context.Background() - consumer := &CSVRowConsumer{ - ctx: ctx, - validators: make(map[int64]*CSVValidator), - } - - collectionInfo, err := NewCollectionInfo(sampleSchema(), 2, []int64{1}) - assert.NoError(t, err) - schema := collectionInfo.Schema - err = consumer.initValidators(schema) - assert.NoError(t, err) - assert.Equal(t, len(schema.Fields), len(consumer.validators)) - for _, field := range schema.Fields { - fieldID := field.GetFieldID() - assert.Equal(t, field.GetName(), 
consumer.validators[fieldID].fieldName) - if field.GetDataType() != schemapb.DataType_VarChar && field.GetDataType() != schemapb.DataType_String { - assert.False(t, consumer.validators[fieldID].isString) - } else { - assert.True(t, consumer.validators[fieldID].isString) - } - } - - name2ID := make(map[string]storage.FieldID) - for _, field := range schema.Fields { - name2ID[field.GetName()] = field.GetFieldID() - } - - fields := initBlockData(schema) - assert.NotNil(t, fields) - - checkConvertFunc := func(funcName string, validVal string, invalidVal string) { - id := name2ID[funcName] - v, ok := consumer.validators[id] - assert.True(t, ok) - - fieldData := fields[id] - preNum := fieldData.RowNum() - err = v.convertFunc(validVal, fieldData) - assert.NoError(t, err) - postNum := fieldData.RowNum() - assert.Equal(t, 1, postNum-preNum) - - err = v.convertFunc(invalidVal, fieldData) - assert.Error(t, err) - } - - t.Run("check convert functions", func(t *testing.T) { - // all val is string type - validVal := "true" - invalidVal := "5" - checkConvertFunc("FieldBool", validVal, invalidVal) - - validVal = "100" - invalidVal = "128" - checkConvertFunc("FieldInt8", validVal, invalidVal) - - invalidVal = "65536" - checkConvertFunc("FieldInt16", validVal, invalidVal) - - invalidVal = "2147483648" - checkConvertFunc("FieldInt32", validVal, invalidVal) - - invalidVal = "1.2" - checkConvertFunc("FieldInt64", validVal, invalidVal) - - invalidVal = "dummy" - checkConvertFunc("FieldFloat", validVal, invalidVal) - checkConvertFunc("FieldDouble", validVal, invalidVal) - - // json type - validVal = `{"x": 5, "y": true, "z": "hello"}` - checkConvertFunc("FieldJSON", validVal, "a") - checkConvertFunc("FieldJSON", validVal, "{") - - // the binary vector dimension is 16, shoud input two uint8 values, each value should between 0~255 - validVal = "[100, 101]" - invalidVal = "[100, 1256]" - checkConvertFunc("FieldBinaryVector", validVal, invalidVal) - - invalidVal = "false" - checkConvertFunc("FieldBinaryVector", validVal, invalidVal) - invalidVal = "[100]" - checkConvertFunc("FieldBinaryVector", validVal, invalidVal) - invalidVal = "[100.2, 102.5]" - checkConvertFunc("FieldBinaryVector", validVal, invalidVal) - - // the float vector dimension is 4, each value should be valid float number - validVal = "[1,2,3,4]" - invalidVal = `[1,2,3,"dummy"]` - checkConvertFunc("FieldFloatVector", validVal, invalidVal) - invalidVal = "true" - checkConvertFunc("FieldFloatVector", validVal, invalidVal) - invalidVal = `[1]` - checkConvertFunc("FieldFloatVector", validVal, invalidVal) - - validVal = "[1,2,3,4]" - invalidVal = "[bool, false]" - checkConvertFunc("FieldArray", validVal, invalidVal) - }) - - t.Run("init error cases", func(t *testing.T) { - // schema is nil - err := consumer.initValidators(nil) - assert.Error(t, err) - - schema = &schemapb.CollectionSchema{ - Name: "schema", - Description: "schema", - AutoID: true, - Fields: make([]*schemapb.FieldSchema, 0), - } - schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ - FieldID: 111, - Name: "FieldFloatVector", - IsPrimaryKey: false, - DataType: schemapb.DataType_FloatVector, - TypeParams: []*commonpb.KeyValuePair{ - {Key: common.DimKey, Value: "aa"}, - }, - }) - consumer.validators = make(map[int64]*CSVValidator) - err = consumer.initValidators(schema) - assert.Error(t, err) - - schema.Fields = make([]*schemapb.FieldSchema, 0) - schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ - FieldID: 110, - Name: "FieldBinaryVector", - IsPrimaryKey: false, - DataType: 
schemapb.DataType_BinaryVector, - TypeParams: []*commonpb.KeyValuePair{ - {Key: common.DimKey, Value: "aa"}, - }, - }) - - err = consumer.initValidators(schema) - assert.Error(t, err) - - // unsupported data type - schema.Fields = make([]*schemapb.FieldSchema, 0) - schema.Fields = append(schema.Fields, &schemapb.FieldSchema{ - FieldID: 110, - Name: "dummy", - IsPrimaryKey: false, - DataType: schemapb.DataType_None, - }) - - err = consumer.initValidators(schema) - assert.Error(t, err) - }) - - t.Run("json field", func(t *testing.T) { - schema = &schemapb.CollectionSchema{ - Name: "schema", - Description: "schema", - AutoID: true, - Fields: []*schemapb.FieldSchema{ - { - FieldID: 102, - Name: "FieldJSON", - DataType: schemapb.DataType_JSON, - }, - }, - } - consumer.validators = make(map[int64]*CSVValidator) - err = consumer.initValidators(schema) - assert.NoError(t, err) - - v, ok := consumer.validators[102] - assert.True(t, ok) - - fields := initBlockData(schema) - assert.NotNil(t, fields) - fieldData := fields[102] - - err = v.convertFunc("{\"x\": 1, \"y\": 5}", fieldData) - assert.NoError(t, err) - assert.Equal(t, 1, fieldData.RowNum()) - - err = v.convertFunc("{}", fieldData) - assert.NoError(t, err) - assert.Equal(t, 2, fieldData.RowNum()) - - err = v.convertFunc("", fieldData) - assert.Error(t, err) - assert.Equal(t, 2, fieldData.RowNum()) - }) - - t.Run("array field", func(t *testing.T) { - schema = &schemapb.CollectionSchema{ - Name: "schema", - Description: "schema", - AutoID: true, - Fields: []*schemapb.FieldSchema{ - { - FieldID: 113, - Name: "FieldArray", - IsPrimaryKey: false, - DataType: schemapb.DataType_Array, - TypeParams: []*commonpb.KeyValuePair{ - {Key: "max_capacity", Value: "100"}, - }, - ElementType: schemapb.DataType_Bool, - }, - }, - } - consumer.validators = make(map[int64]*CSVValidator) - err = consumer.initValidators(schema) - assert.NoError(t, err) - - v, ok := consumer.validators[113] - assert.True(t, ok) - - fields := initBlockData(schema) - assert.NotNil(t, fields) - fieldData := fields[113] - - err = v.convertFunc("[true, false]", fieldData) - assert.NoError(t, err) - assert.Equal(t, 1, fieldData.RowNum()) - - schema = &schemapb.CollectionSchema{ - Name: "schema", - Description: "schema", - AutoID: true, - Fields: []*schemapb.FieldSchema{ - { - FieldID: 113, - Name: "FieldArray", - IsPrimaryKey: false, - DataType: schemapb.DataType_Array, - TypeParams: []*commonpb.KeyValuePair{ - {Key: "max_capacity", Value: "100"}, - }, - ElementType: schemapb.DataType_Int64, - }, - }, - } - consumer.validators = make(map[int64]*CSVValidator) - err = consumer.initValidators(schema) - assert.NoError(t, err) - - v, ok = consumer.validators[113] - assert.True(t, ok) - - fields = initBlockData(schema) - assert.NotNil(t, fields) - fieldData = fields[113] - - err = v.convertFunc("[1,2,3,4]", fieldData) - assert.NoError(t, err) - assert.Equal(t, 1, fieldData.RowNum()) - - schema = &schemapb.CollectionSchema{ - Name: "schema", - Description: "schema", - AutoID: true, - Fields: []*schemapb.FieldSchema{ - { - FieldID: 113, - Name: "FieldArray", - IsPrimaryKey: false, - DataType: schemapb.DataType_Array, - TypeParams: []*commonpb.KeyValuePair{ - {Key: "max_capacity", Value: "100"}, - }, - ElementType: schemapb.DataType_Float, - }, - }, - } - consumer.validators = make(map[int64]*CSVValidator) - err = consumer.initValidators(schema) - assert.NoError(t, err) - - v, ok = consumer.validators[113] - assert.True(t, ok) - - fields = initBlockData(schema) - assert.NotNil(t, fields) - fieldData 
= fields[113] - - err = v.convertFunc("[1.1,2.2,3.3,4.4]", fieldData) - assert.NoError(t, err) - assert.Equal(t, 1, fieldData.RowNum()) - - schema = &schemapb.CollectionSchema{ - Name: "schema", - Description: "schema", - AutoID: true, - Fields: []*schemapb.FieldSchema{ - { - FieldID: 113, - Name: "FieldArray", - IsPrimaryKey: false, - DataType: schemapb.DataType_Array, - TypeParams: []*commonpb.KeyValuePair{ - {Key: "max_capacity", Value: "100"}, - }, - ElementType: schemapb.DataType_Double, - }, - }, - } - consumer.validators = make(map[int64]*CSVValidator) - err = consumer.initValidators(schema) - assert.NoError(t, err) - - v, ok = consumer.validators[113] - assert.True(t, ok) - - fields = initBlockData(schema) - assert.NotNil(t, fields) - fieldData = fields[113] - - err = v.convertFunc("[1.2,2.3,3.4,4.5]", fieldData) - assert.NoError(t, err) - assert.Equal(t, 1, fieldData.RowNum()) - - schema = &schemapb.CollectionSchema{ - Name: "schema", - Description: "schema", - AutoID: true, - Fields: []*schemapb.FieldSchema{ - { - FieldID: 113, - Name: "FieldArray", - IsPrimaryKey: false, - DataType: schemapb.DataType_Array, - TypeParams: []*commonpb.KeyValuePair{ - {Key: "max_capacity", Value: "100"}, - }, - ElementType: schemapb.DataType_VarChar, - }, - }, - } - consumer.validators = make(map[int64]*CSVValidator) - err = consumer.initValidators(schema) - assert.NoError(t, err) - - v, ok = consumer.validators[113] - assert.True(t, ok) - - fields = initBlockData(schema) - assert.NotNil(t, fields) - fieldData = fields[113] - - err = v.convertFunc(`["abc", "vv"]`, fieldData) - assert.NoError(t, err) - assert.Equal(t, 1, fieldData.RowNum()) - }) -} - -func Test_CSVRowConsumerHandleIntPK(t *testing.T) { - ctx := context.Background() - - t.Run("nil input", func(t *testing.T) { - var consumer *CSVRowConsumer - err := consumer.Handle(nil) - assert.Error(t, err) - }) - - schema := &schemapb.CollectionSchema{ - Name: "schema", - Fields: []*schemapb.FieldSchema{ - { - FieldID: 101, - Name: "FieldInt64", - IsPrimaryKey: true, - AutoID: true, - DataType: schemapb.DataType_Int64, - }, - { - FieldID: 102, - Name: "FieldVarchar", - DataType: schemapb.DataType_VarChar, - }, - { - FieldID: 103, - Name: "FieldFloat", - DataType: schemapb.DataType_Float, - }, - }, - } - createConsumeFunc := func(shardNum int32, partitionIDs []int64, flushFunc ImportFlushFunc) *CSVRowConsumer { - collectionInfo, err := NewCollectionInfo(schema, shardNum, partitionIDs) - assert.NoError(t, err) - - idAllocator := newIDAllocator(ctx, t, nil) - consumer, err := NewCSVRowConsumer(ctx, collectionInfo, idAllocator, 1, flushFunc) - assert.NotNil(t, consumer) - assert.NoError(t, err) - - return consumer - } - - t.Run("auto pk no partition key", func(t *testing.T) { - flushErrFunc := func(fields BlockData, shard int, partID int64) error { - return errors.New("dummy error") - } - - // rows to input - inputRowCount := 100 - input := make([]map[storage.FieldID]string, inputRowCount) - for i := 0; i < inputRowCount; i++ { - input[i] = map[storage.FieldID]string{ - 102: "string", - 103: "122.5", - } - } - - shardNum := int32(2) - partitionID := int64(1) - consumer := createConsumeFunc(shardNum, []int64{partitionID}, flushErrFunc) - consumer.rowIDAllocator = newIDAllocator(ctx, t, errors.New("error")) - - waitFlushRowCount := 10 - fieldData := createFieldsData(schema, waitFlushRowCount) - consumer.shardsData = createShardsData(schema, fieldData, shardNum, []int64{partitionID}) - - // nil input will trigger force flush, flushErrFunc returns error - 
err := consumer.Handle(nil) - assert.Error(t, err) - - // optional flush, flushErrFunc returns error - err = consumer.Handle(input) - assert.Error(t, err) - - // reset flushFunc - var callTime int32 - var flushedRowCount int - consumer.callFlushFunc = func(fields BlockData, shard int, partID int64) error { - callTime++ - assert.Less(t, int32(shard), shardNum) - assert.Equal(t, partitionID, partID) - assert.Greater(t, len(fields), 0) - for _, v := range fields { - assert.Greater(t, v.RowNum(), 0) - } - flushedRowCount += fields[102].RowNum() - return nil - } - // optional flush succeed, each shard has 10 rows, idErrAllocator returns error - err = consumer.Handle(input) - assert.Error(t, err) - assert.Equal(t, waitFlushRowCount*int(shardNum), flushedRowCount) - assert.Equal(t, shardNum, callTime) - - // optional flush again, large blockSize, nothing flushed, idAllocator returns error - callTime = int32(0) - flushedRowCount = 0 - consumer.shardsData = createShardsData(schema, fieldData, shardNum, []int64{partitionID}) - consumer.rowIDAllocator = nil - consumer.blockSize = 8 * 1024 * 1024 - err = consumer.Handle(input) - assert.Error(t, err) - assert.Equal(t, 0, flushedRowCount) - assert.Equal(t, int32(0), callTime) - - // idAllocator is ok, consume 100 rows, the previous shardsData(10 rows per shard) is flushed - callTime = int32(0) - flushedRowCount = 0 - consumer.blockSize = 1 - consumer.rowIDAllocator = newIDAllocator(ctx, t, nil) - err = consumer.Handle(input) - assert.NoError(t, err) - assert.Equal(t, waitFlushRowCount*int(shardNum), flushedRowCount) - assert.Equal(t, shardNum, callTime) - assert.Equal(t, int64(inputRowCount), consumer.RowCount()) - assert.Equal(t, 2, len(consumer.IDRange())) - assert.Equal(t, int64(1), consumer.IDRange()[0]) - assert.Equal(t, int64(1+inputRowCount), consumer.IDRange()[1]) - - // call handle again, the 100 rows are flushed - callTime = int32(0) - flushedRowCount = 0 - err = consumer.Handle(nil) - assert.NoError(t, err) - assert.Equal(t, inputRowCount, flushedRowCount) - assert.Equal(t, shardNum, callTime) - }) - - schema.Fields[0].AutoID = false - - t.Run("manual pk no partition key", func(t *testing.T) { - shardNum := int32(1) - partitionID := int64(100) - - var callTime int32 - var flushedRowCount int - flushFunc := func(fields BlockData, shard int, partID int64) error { - callTime++ - assert.Less(t, int32(shard), shardNum) - assert.Equal(t, partitionID, partID) - assert.Greater(t, len(fields), 0) - flushedRowCount += fields[102].RowNum() - return nil - } - - consumer := createConsumeFunc(shardNum, []int64{partitionID}, flushFunc) - - // failed to convert pk to int value - input := make([]map[storage.FieldID]string, 1) - input[0] = map[int64]string{ - 101: "abc", - 102: "string", - 103: "11.11", - } - - err := consumer.Handle(input) - assert.Error(t, err) - - // failed to hash to partition - input[0] = map[int64]string{ - 101: "99", - 102: "string", - 103: "11.11", - } - consumer.collectionInfo.PartitionIDs = nil - err = consumer.Handle(input) - assert.Error(t, err) - consumer.collectionInfo.PartitionIDs = []int64{partitionID} - - // failed to convert value - input[0] = map[int64]string{ - 101: "99", - 102: "string", - 103: "abc.11", - } - err = consumer.Handle(input) - assert.Error(t, err) - consumer.shardsData = createShardsData(schema, nil, shardNum, []int64{partitionID}) // in-memory data is dirty, reset - - // succeed, consum 1 row - input[0] = map[int64]string{ - 101: "99", - 102: "string", - 103: "11.11", - } - err = consumer.Handle(input) - 
assert.NoError(t, err) - assert.Equal(t, int64(1), consumer.RowCount()) - assert.Equal(t, 0, len(consumer.IDRange())) - - // call handle again, the 1 row is flushed - callTime = int32(0) - flushedRowCount = 0 - err = consumer.Handle(nil) - assert.NoError(t, err) - assert.Equal(t, 1, flushedRowCount) - assert.Equal(t, shardNum, callTime) - }) - - schema.Fields[1].IsPartitionKey = true - - t.Run("manual pk with partition key", func(t *testing.T) { - // 10 partitions - partitionIDs := make([]int64, 0) - for i := 0; i < 10; i++ { - partitionIDs = append(partitionIDs, int64(i)) - } - - shardNum := int32(2) - var flushedRowCount int - flushFunc := func(fields BlockData, shard int, partID int64) error { - assert.Less(t, int32(shard), shardNum) - assert.Contains(t, partitionIDs, partID) - assert.Greater(t, len(fields), 0) - flushedRowCount += fields[102].RowNum() - return nil - } - - consumer := createConsumeFunc(shardNum, partitionIDs, flushFunc) - - // rows to input - inputRowCount := 100 - input := make([]map[storage.FieldID]string, inputRowCount) - for i := 0; i < inputRowCount; i++ { - input[i] = map[int64]string{ - 101: strconv.Itoa(i), - 102: "partitionKey_" + strconv.Itoa(i), - 103: "6.18", - } - } - - // 100 rows are consumed to different partitions - err := consumer.Handle(input) - assert.NoError(t, err) - assert.Equal(t, int64(inputRowCount), consumer.RowCount()) - - // call handle again, 100 rows are flushed - flushedRowCount = 0 - err = consumer.Handle(nil) - assert.NoError(t, err) - assert.Equal(t, inputRowCount, flushedRowCount) - }) -} - -func Test_CSVRowConsumerHandleVarcharPK(t *testing.T) { - ctx := context.Background() - - schema := &schemapb.CollectionSchema{ - Name: "schema", - Fields: []*schemapb.FieldSchema{ - { - FieldID: 101, - Name: "FieldVarchar", - IsPrimaryKey: true, - AutoID: false, - DataType: schemapb.DataType_VarChar, - }, - { - FieldID: 102, - Name: "FieldInt64", - DataType: schemapb.DataType_Int64, - }, - { - FieldID: 103, - Name: "FieldFloat", - DataType: schemapb.DataType_Float, - }, - }, - } - - createConsumeFunc := func(shardNum int32, partitionIDs []int64, flushFunc ImportFlushFunc) *CSVRowConsumer { - collectionInfo, err := NewCollectionInfo(schema, shardNum, partitionIDs) - assert.NoError(t, err) - - idAllocator := newIDAllocator(ctx, t, nil) - consumer, err := NewCSVRowConsumer(ctx, collectionInfo, idAllocator, 1, flushFunc) - assert.NotNil(t, consumer) - assert.NoError(t, err) - - return consumer - } - - t.Run("no partition key", func(t *testing.T) { - shardNum := int32(2) - partitionID := int64(1) - var callTime int32 - var flushedRowCount int - flushFunc := func(fields BlockData, shard int, partID int64) error { - callTime++ - assert.Less(t, int32(shard), shardNum) - assert.Equal(t, partitionID, partID) - assert.Greater(t, len(fields), 0) - for _, v := range fields { - assert.Greater(t, v.RowNum(), 0) - } - flushedRowCount += fields[102].RowNum() - return nil - } - - consumer := createConsumeFunc(shardNum, []int64{partitionID}, flushFunc) - consumer.shardsData = createShardsData(schema, nil, shardNum, []int64{partitionID}) - - // string type primary key cannot be auto-generated - input := make([]map[storage.FieldID]string, 1) - input[0] = map[storage.FieldID]string{ - 101: "primaryKey_0", - 102: "1", - 103: "1.252", - } - - consumer.collectionInfo.PrimaryKey.AutoID = true - err := consumer.Handle(input) - assert.Error(t, err) - consumer.collectionInfo.PrimaryKey.AutoID = false - - // failed to hash to partition - consumer.collectionInfo.PartitionIDs = 
nil - err = consumer.Handle(input) - assert.Error(t, err) - consumer.collectionInfo.PartitionIDs = []int64{partitionID} - - // rows to input - inputRowCount := 100 - input = make([]map[storage.FieldID]string, inputRowCount) - for i := 0; i < inputRowCount; i++ { - input[i] = map[int64]string{ - 101: "primaryKey_" + strconv.Itoa(i), - 102: strconv.Itoa(i), - 103: "6.18", - } - } - - err = consumer.Handle(input) - assert.NoError(t, err) - assert.Equal(t, int64(inputRowCount), consumer.RowCount()) - assert.Equal(t, 0, len(consumer.IDRange())) - - // call handle again, 100 rows are flushed - err = consumer.Handle(nil) - assert.NoError(t, err) - assert.Equal(t, inputRowCount, flushedRowCount) - assert.Equal(t, shardNum, callTime) - }) - - schema.Fields[1].IsPartitionKey = true - t.Run("has partition key", func(t *testing.T) { - partitionIDs := make([]int64, 0) - for i := 0; i < 10; i++ { - partitionIDs = append(partitionIDs, int64(i)) - } - - shardNum := int32(2) - var flushedRowCount int - flushFunc := func(fields BlockData, shard int, partID int64) error { - assert.Less(t, int32(shard), shardNum) - assert.Contains(t, partitionIDs, partID) - assert.Greater(t, len(fields), 0) - flushedRowCount += fields[102].RowNum() - return nil - } - - consumer := createConsumeFunc(shardNum, partitionIDs, flushFunc) - - // rows to input - inputRowCount := 100 - input := make([]map[storage.FieldID]string, inputRowCount) - for i := 0; i < inputRowCount; i++ { - input[i] = map[int64]string{ - 101: "primaryKey_" + strconv.Itoa(i), - 102: strconv.Itoa(i), - 103: "6.18", - } - } - - err := consumer.Handle(input) - assert.NoError(t, err) - assert.Equal(t, int64(inputRowCount), consumer.RowCount()) - assert.Equal(t, 0, len(consumer.IDRange())) - - // call handle again, 100 rows are flushed - err = consumer.Handle(nil) - assert.NoError(t, err) - assert.Equal(t, inputRowCount, flushedRowCount) - }) -} - -func Test_CSVRowConsumerHashToPartition(t *testing.T) { - ctx := context.Background() - - schema := &schemapb.CollectionSchema{ - Name: "schema", - Fields: []*schemapb.FieldSchema{ - { - FieldID: 100, - Name: "ID", - IsPrimaryKey: true, - AutoID: false, - DataType: schemapb.DataType_Int64, - }, - { - FieldID: 101, - Name: "FieldVarchar", - DataType: schemapb.DataType_VarChar, - }, - { - FieldID: 102, - Name: "FieldInt64", - DataType: schemapb.DataType_Int64, - }, - }, - } - - partitionID := int64(1) - collectionInfo, err := NewCollectionInfo(schema, 2, []int64{partitionID}) - assert.NoError(t, err) - consumer, err := NewCSVRowConsumer(ctx, collectionInfo, nil, 16, nil) - assert.NoError(t, err) - assert.NotNil(t, consumer) - input := map[int64]string{ - 100: "1", - 101: "abc", - 102: "100", - } - t.Run("no partition key", func(t *testing.T) { - partID, err := consumer.hashToPartition(input, 0) - assert.NoError(t, err) - assert.Equal(t, partitionID, partID) - }) - - t.Run("partition list is empty", func(t *testing.T) { - collectionInfo.PartitionIDs = []int64{} - partID, err := consumer.hashToPartition(input, 0) - assert.Error(t, err) - assert.Equal(t, int64(0), partID) - collectionInfo.PartitionIDs = []int64{partitionID} - }) - - schema.Fields[1].IsPartitionKey = true - err = collectionInfo.resetSchema(schema) - assert.NoError(t, err) - collectionInfo.PartitionIDs = []int64{1, 2, 3} - - t.Run("varchar partition key", func(t *testing.T) { - input = map[int64]string{ - 100: "1", - 101: "abc", - 102: "100", - } - - partID, err := consumer.hashToPartition(input, 0) - assert.NoError(t, err) - assert.Contains(t, 
collectionInfo.PartitionIDs, partID) - }) - - schema.Fields[1].IsPartitionKey = false - schema.Fields[2].IsPartitionKey = true - err = collectionInfo.resetSchema(schema) - assert.NoError(t, err) - - t.Run("int64 partition key", func(t *testing.T) { - input = map[int64]string{ - 100: "1", - 101: "abc", - 102: "ab0", - } - // parse int failed - partID, err := consumer.hashToPartition(input, 0) - assert.Error(t, err) - assert.Equal(t, int64(0), partID) - - // succeed - input[102] = "100" - partID, err = consumer.hashToPartition(input, 0) - assert.NoError(t, err) - assert.Contains(t, collectionInfo.PartitionIDs, partID) - }) -} diff --git a/internal/util/importutil/csv_parser.go b/internal/util/importutil/csv_parser.go deleted file mode 100644 index fb537c189a..0000000000 --- a/internal/util/importutil/csv_parser.go +++ /dev/null @@ -1,319 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package importutil - -import ( - "context" - "encoding/json" - "fmt" - "io" - "strconv" - "strings" - - "go.uber.org/zap" - - "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/pkg/log" - "github.com/milvus-io/milvus/pkg/util/merr" - "github.com/milvus-io/milvus/pkg/util/typeutil" -) - -type CSVParser struct { - ctx context.Context // for canceling parse process - collectionInfo *CollectionInfo // collection details including schema - bufRowCount int // max rows in a buffer - fieldsName []string // fieldsName(header name) in the csv file - updateProgressFunc func(percent int64) // update working progress percent value -} - -func NewCSVParser(ctx context.Context, collectionInfo *CollectionInfo, updateProgressFunc func(percent int64)) (*CSVParser, error) { - if collectionInfo == nil { - log.Warn("CSV parser: collection schema is nil") - return nil, merr.WrapErrImportFailed("collection schema is nil") - } - - parser := &CSVParser{ - ctx: ctx, - collectionInfo: collectionInfo, - bufRowCount: 1024, - fieldsName: make([]string, 0), - updateProgressFunc: updateProgressFunc, - } - parser.SetBufSize() - return parser, nil -} - -func (p *CSVParser) SetBufSize() { - schema := p.collectionInfo.Schema - sizePerRecord, _ := typeutil.EstimateSizePerRecord(schema) - if sizePerRecord <= 0 { - return - } - - bufRowCount := p.bufRowCount - for { - if bufRowCount*sizePerRecord > ReadBufferSize { - bufRowCount-- - } else { - break - } - } - if bufRowCount <= 0 { - bufRowCount = 1 - } - log.Info("CSV parser: reset bufRowCount", zap.Int("sizePerRecord", sizePerRecord), zap.Int("bufRowCount", bufRowCount)) - p.bufRowCount = bufRowCount -} - -func (p *CSVParser) combineDynamicRow(dynamicValues map[string]string, row map[storage.FieldID]string) error { - if p.collectionInfo.DynamicField == nil { - return nil - } - - dynamicFieldID := 
p.collectionInfo.DynamicField.GetFieldID() - // combine the dynamic field value - // valid input: - // id,vector,x,$meta id,vector,$meta - // case1: 1,"[]",8,"{""y"": 8}" ==>> 1,"[]","{""y"": 8, ""x"": 8}" - // case2: 1,"[]",8,"{}" ==>> 1,"[]","{""x"": 8}" - // case3: 1,"[]",,"{""x"": 8}" - // case4: 1,"[]",8, ==>> 1,"[]","{""x"": 8}" - // case5: 1,"[]",, - value, ok := row[dynamicFieldID] - // ignore empty string field - if value == "" { - ok = false - } - if len(dynamicValues) > 0 { - mp := make(map[string]interface{}) - if ok { - // case 1/2 - // $meta is JSON type field, we first convert it to map[string]interface{} - // then merge other dynamic field into it - desc := json.NewDecoder(strings.NewReader(value)) - desc.UseNumber() - if err := desc.Decode(&mp); err != nil { - log.Warn("CSV parser: illegal value for dynamic field, not a JSON object") - return merr.WrapErrImportFailed("illegal value for dynamic field, not a JSON object") - } - } - // case 4 - for k, v := range dynamicValues { - // ignore empty string field - if v == "" { - continue - } - var value interface{} - - desc := json.NewDecoder(strings.NewReader(v)) - desc.UseNumber() - if err := desc.Decode(&value); err != nil { - // Decode a string will cause error, like "abcd" - mp[k] = v - continue - } - - if num, ok := value.(json.Number); ok { - // Decode may convert "123ab" to 123, so need additional check - if _, err := strconv.ParseFloat(v, 64); err != nil { - mp[k] = v - } else { - mp[k] = num - } - } else if arr, ok := value.([]interface{}); ok { - mp[k] = arr - } else if obj, ok := value.(map[string]interface{}); ok { - mp[k] = obj - } else if b, ok := value.(bool); ok { - mp[k] = b - } - } - bs, err := json.Marshal(mp) - if err != nil { - log.Warn("CSV parser: illegal value for dynamic field, not a JSON object") - return merr.WrapErrImportFailed("illegal value for dynamic field, not a JSON object") - } - row[dynamicFieldID] = string(bs) - } else if !ok && len(dynamicValues) == 0 { - // case 5 - row[dynamicFieldID] = "{}" - } - // else case 3 - - return nil -} - -func (p *CSVParser) verifyRow(raw []string) (map[storage.FieldID]string, error) { - row := make(map[storage.FieldID]string) - dynamicValues := make(map[string]string) - - for i := 0; i < len(p.fieldsName); i++ { - fieldName := p.fieldsName[i] - fieldID, ok := p.collectionInfo.Name2FieldID[fieldName] - - if fieldID == p.collectionInfo.PrimaryKey.GetFieldID() && p.collectionInfo.PrimaryKey.GetAutoID() { - // primary key is auto-id, no need to provide - log.Warn("CSV parser: the primary key is auto-generated, no need to provide", zap.String("fieldName", fieldName)) - return nil, merr.WrapErrImportFailed(fmt.Sprintf("the primary key '%s' is auto-generated, no need to provide", fieldName)) - } - - if ok { - row[fieldID] = raw[i] - } else if p.collectionInfo.DynamicField != nil { - // collection have dynamic field. put it to dynamicValues - dynamicValues[fieldName] = raw[i] - } else { - // no dynamic field. if user provided redundant field, return error - log.Warn("CSV parser: the field is not defined in collection schema", zap.String("fieldName", fieldName)) - return nil, merr.WrapErrImportFailed(fmt.Sprintf("the field '%s' is not defined in collection schema", fieldName)) - } - } - // some fields not provided? 
- if len(row) != len(p.collectionInfo.Name2FieldID) { - for k, v := range p.collectionInfo.Name2FieldID { - if p.collectionInfo.DynamicField != nil && v == p.collectionInfo.DynamicField.GetFieldID() { - // ignore dyanmic field, user don't have to provide values for dynamic field - continue - } - - if v == p.collectionInfo.PrimaryKey.GetFieldID() && p.collectionInfo.PrimaryKey.GetAutoID() { - // ignore auto-generaed primary key - continue - } - _, ok := row[v] - if !ok { - // not auto-id primary key, no dynamic field, must provide value - log.Warn("CSV parser: a field value is missed", zap.String("fieldName", k)) - return nil, merr.WrapErrImportFailed(fmt.Sprintf("value of field '%s' is missed", k)) - } - } - } - // combine the redundant pairs into dynamic field(if has) - err := p.combineDynamicRow(dynamicValues, row) - if err != nil { - log.Warn("CSV parser: failed to combine dynamic values", zap.Error(err)) - return nil, err - } - - return row, nil -} - -func (p *CSVParser) ParseRows(reader *IOReader, handle CSVRowHandler) error { - if reader == nil || handle == nil { - log.Warn("CSV Parser: CSV parse handle is nil") - return merr.WrapErrImportFailed("CSV parse handle is nil") - } - // discard bom in the file - RuneScanner := reader.r.(io.RuneScanner) - bom, _, err := RuneScanner.ReadRune() - if err == io.EOF { - log.Info("CSV Parser: row count is 0") - return nil - } - if err != nil { - return err - } - if bom != '\ufeff' { - if err = RuneScanner.UnreadRune(); err != nil { - return err - } - } - r := NewReader(reader.r) - - oldPercent := int64(0) - updateProgress := func() { - if p.updateProgressFunc != nil && reader.fileSize > 0 { - percent := (r.InputOffset() * ProgressValueForPersist) / reader.fileSize - if percent > oldPercent { // avoid too many log - log.Debug("CSV parser: working progress", zap.Int64("offset", r.InputOffset()), - zap.Int64("fileSize", reader.fileSize), zap.Int64("percent", percent)) - } - oldPercent = percent - p.updateProgressFunc(percent) - } - } - isEmpty := true - for { - // read the fields value - fieldsName, err := r.Read() - if err == io.EOF { - break - } else if err != nil { - log.Warn("CSV Parser: failed to parse the field value", zap.Error(err)) - return merr.WrapErrImportFailed(fmt.Sprintf("failed to read the field value, error: %v", err)) - } - p.fieldsName = fieldsName - // read buffer - buf := make([]map[storage.FieldID]string, 0, p.bufRowCount) - for { - // read the row value - values, err := r.Read() - - if err == io.EOF { - break - } else if err != nil { - log.Warn("CSV parser: failed to parse row value", zap.Error(err)) - return merr.WrapErrImportFailed(fmt.Sprintf("failed to parse row value, error: %v", err)) - } - - row, err := p.verifyRow(values) - if err != nil { - return err - } - - updateProgress() - - buf = append(buf, row) - if len(buf) >= p.bufRowCount { - isEmpty = false - if err = handle.Handle(buf); err != nil { - log.Warn("CSV parser: failed to convert row value to entity", zap.Error(err)) - return merr.WrapErrImportFailed(fmt.Sprintf("failed to convert row value to entity, error: %v", err)) - } - // clean the buffer - buf = make([]map[storage.FieldID]string, 0, p.bufRowCount) - } - } - if len(buf) > 0 { - isEmpty = false - if err = handle.Handle(buf); err != nil { - log.Warn("CSV parser: failed to convert row value to entity", zap.Error(err)) - return merr.WrapErrImportFailed(fmt.Sprintf("failed to convert row value to entity, error: %v", err)) - } - } - - // outside context might be canceled(service stop, or future enhancement for 
canceling import task) - if isCanceled(p.ctx) { - log.Warn("CSV parser: import task was canceled") - return merr.WrapErrImportFailed("import task was canceled") - } - // nolint - // this break means we require the first row must be fieldsName - break - } - - // empty file is allowed, don't return error - if isEmpty { - log.Info("CSV Parser: row count is 0") - return nil - } - - updateProgress() - - // send nil to notify the handler all have done - return handle.Handle(nil) -} diff --git a/internal/util/importutil/csv_parser_test.go b/internal/util/importutil/csv_parser_test.go deleted file mode 100644 index c93a34d6a2..0000000000 --- a/internal/util/importutil/csv_parser_test.go +++ /dev/null @@ -1,414 +0,0 @@ -// Licensed to the LF AI & Data foundation under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package importutil - -import ( - "context" - "strings" - "testing" - - "github.com/cockroachdb/errors" - "github.com/stretchr/testify/assert" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" - "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" - "github.com/milvus-io/milvus/internal/storage" - "github.com/milvus-io/milvus/pkg/common" -) - -type mockCSVRowConsumer struct { - handleErr error - rows []map[storage.FieldID]string - handleCount int -} - -func (v *mockCSVRowConsumer) Handle(rows []map[storage.FieldID]string) error { - if v.handleErr != nil { - return v.handleErr - } - if rows != nil { - v.rows = append(v.rows, rows...) 
- } - v.handleCount++ - return nil -} - -func Test_CSVParserAdjustBufSize(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - schema := sampleSchema() - collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1}) - assert.NoError(t, err) - parser, err := NewCSVParser(ctx, collectionInfo, nil) - assert.NoError(t, err) - assert.NotNil(t, parser) - assert.Greater(t, parser.bufRowCount, 0) - // huge row - schema.Fields[9].TypeParams = []*commonpb.KeyValuePair{ - {Key: common.DimKey, Value: "32768"}, - } - parser, err = NewCSVParser(ctx, collectionInfo, nil) - assert.NoError(t, err) - assert.NotNil(t, parser) - assert.Greater(t, parser.bufRowCount, 0) -} - -func Test_CSVParserParseRows_IntPK(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - schema := sampleSchema() - collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1}) - assert.NoError(t, err) - parser, err := NewCSVParser(ctx, collectionInfo, nil) - assert.NoError(t, err) - assert.NotNil(t, parser) - - consumer := &mockCSVRowConsumer{ - handleErr: nil, - rows: make([]map[int64]string, 0), - handleCount: 0, - } - - reader := strings.NewReader( - `FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector,FieldArray - true,10,101,1001,10001,3.14,1.56,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]","[1,2,3,4]"`) - - t.Run("parse success", func(t *testing.T) { - err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer) - assert.NoError(t, err) - - // empty file - reader = strings.NewReader(``) - err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(0)}, consumer) - assert.NoError(t, err) - - // only have headers no value row - reader = strings.NewReader(`FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector`) - err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer) - assert.NoError(t, err) - - // csv file have bom - reader = strings.NewReader(`\ufeffFieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector`) - err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer) - assert.NoError(t, err) - }) - - t.Run("error cases", func(t *testing.T) { - // handler is nil - - err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(0)}, nil) - assert.Error(t, err) - - // csv parse error, fields len error - reader := strings.NewReader( - `FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector,FieldArray - 0,100,1000,99999999999999999,3,1,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]","[1,2,3,4]"`) - err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer) - assert.Error(t, err) - - // redundant field - reader = strings.NewReader( - `dummy,FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector,FieldArray - 1,true,0,100,1000,99999999999999999,3,1,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]","[1,2,3,4]"`) - err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer) - assert.Error(t, err) - - // field missed - reader = strings.NewReader( - `FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector,FieldArray - 
0,100,1000,99999999999999999,3,1,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]","[1,2,3,4]"`) - err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer) - assert.Error(t, err) - - // handle() error - content := `FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector,FieldArray - true,0,100,1000,99999999999999999,3,1,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]","[1,2,3,4]"` - consumer.handleErr = errors.New("error") - reader = strings.NewReader(content) - err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer) - assert.Error(t, err) - - // canceled - consumer.handleErr = nil - cancel() - reader = strings.NewReader(content) - err = parser.ParseRows(&IOReader{r: reader, fileSize: int64(100)}, consumer) - assert.Error(t, err) - }) -} - -func Test_CSVParserCombineDynamicRow(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - schema := &schemapb.CollectionSchema{ - Name: "schema", - Description: "schema", - EnableDynamicField: true, - Fields: []*schemapb.FieldSchema{ - { - FieldID: 106, - Name: "FieldID", - IsPrimaryKey: true, - AutoID: false, - Description: "int64", - DataType: schemapb.DataType_Int64, - }, - { - FieldID: 113, - Name: "FieldDynamic", - IsPrimaryKey: false, - IsDynamic: true, - Description: "dynamic field", - DataType: schemapb.DataType_JSON, - }, - }, - } - - collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1}) - assert.NoError(t, err) - parser, err := NewCSVParser(ctx, collectionInfo, nil) - assert.NoError(t, err) - assert.NotNil(t, parser) - - // valid input: - // id,vector,x,$meta id,vector,$meta - // case1: 1,"[]",8,"{""y"": 8}" ==>> 1,"[]","{""y"": 8, ""x"": 8}" - // case2: 1,"[]",8,"{}" ==>> 1,"[]","{""x"": 8}" - // case3: 1,"[]",,"{""x"": 8}" - // case4: 1,"[]",8, ==>> 1,"[]","{""x"": 8}" - // case5: 1,"[]",, - - t.Run("value combined for dynamic field", func(t *testing.T) { - dynamicValues := map[string]string{ - "x": "88", - } - row := map[storage.FieldID]string{ - 106: "1", - 113: `{"y": 8}`, - } - err = parser.combineDynamicRow(dynamicValues, row) - assert.NoError(t, err) - assert.Contains(t, row, int64(113)) - assert.Contains(t, row[113], "x") - assert.Contains(t, row[113], "y") - - row = map[storage.FieldID]string{ - 106: "1", - 113: `{}`, - } - err = parser.combineDynamicRow(dynamicValues, row) - assert.NoError(t, err) - assert.Contains(t, row, int64(113)) - assert.Contains(t, row[113], "x") - }) - - t.Run("JSON format string/object for dynamic field", func(t *testing.T) { - dynamicValues := map[string]string{} - row := map[storage.FieldID]string{ - 106: "1", - 113: `{"x": 8}`, - } - err = parser.combineDynamicRow(dynamicValues, row) - assert.NoError(t, err) - assert.Contains(t, row, int64(113)) - }) - - t.Run("dynamic field is hidden", func(t *testing.T) { - dynamicValues := map[string]string{ - "x": "8", - } - row := map[storage.FieldID]string{ - 106: "1", - } - err = parser.combineDynamicRow(dynamicValues, row) - assert.NoError(t, err) - assert.Contains(t, row, int64(113)) - assert.Contains(t, row, int64(113)) - assert.Contains(t, row[113], "x") - }) - - t.Run("no values for dynamic field", func(t *testing.T) { - dynamicValues := map[string]string{} - row := map[storage.FieldID]string{ - 106: "1", - } - err = parser.combineDynamicRow(dynamicValues, row) - assert.NoError(t, err) - assert.Contains(t, row, int64(113)) - assert.Equal(t, "{}", row[113]) - }) - - t.Run("empty value for dynamic 
field", func(t *testing.T) { - dynamicValues := map[string]string{ - "x": "", - } - row := map[storage.FieldID]string{ - 106: "1", - 113: `{"y": 8}`, - } - err = parser.combineDynamicRow(dynamicValues, row) - assert.NoError(t, err) - assert.Contains(t, row, int64(113)) - assert.Contains(t, row[113], "y") - assert.NotContains(t, row[113], "x") - - row = map[storage.FieldID]string{ - 106: "1", - 113: "", - } - err = parser.combineDynamicRow(dynamicValues, row) - assert.NoError(t, err) - assert.Equal(t, "{}", row[113]) - - dynamicValues = map[string]string{ - "x": "5", - } - err = parser.combineDynamicRow(dynamicValues, row) - assert.NoError(t, err) - assert.Contains(t, row[113], "x") - }) - - t.Run("invalid input for dynamic field", func(t *testing.T) { - dynamicValues := map[string]string{ - "x": "8", - } - row := map[storage.FieldID]string{ - 106: "1", - 113: "5", - } - err = parser.combineDynamicRow(dynamicValues, row) - assert.Error(t, err) - - row = map[storage.FieldID]string{ - 106: "1", - 113: "abc", - } - err = parser.combineDynamicRow(dynamicValues, row) - assert.Error(t, err) - }) - - t.Run("not allow dynamic values if no dynamic field", func(t *testing.T) { - parser.collectionInfo.DynamicField = nil - dynamicValues := map[string]string{ - "x": "8", - } - row := map[storage.FieldID]string{ - 106: "1", - } - err = parser.combineDynamicRow(dynamicValues, row) - assert.NoError(t, err) - assert.NotContains(t, row, int64(113)) - }) -} - -func Test_CSVParserVerifyRow(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - schema := &schemapb.CollectionSchema{ - Name: "schema", - Description: "schema", - EnableDynamicField: true, - Fields: []*schemapb.FieldSchema{ - { - FieldID: 106, - Name: "FieldID", - IsPrimaryKey: true, - AutoID: false, - Description: "int64", - DataType: schemapb.DataType_Int64, - }, - { - FieldID: 113, - Name: "FieldDynamic", - IsPrimaryKey: false, - IsDynamic: true, - Description: "dynamic field", - DataType: schemapb.DataType_JSON, - }, - }, - } - - collectionInfo, err := NewCollectionInfo(schema, 2, []int64{1}) - assert.NoError(t, err) - parser, err := NewCSVParser(ctx, collectionInfo, nil) - assert.NoError(t, err) - assert.NotNil(t, parser) - - t.Run("not auto-id, dynamic field provided", func(t *testing.T) { - parser.fieldsName = []string{"FieldID", "FieldDynamic", "y"} - raw := []string{"1", `{"x": 8}`, "true"} - row, err := parser.verifyRow(raw) - assert.NoError(t, err) - assert.Contains(t, row, int64(106)) - assert.Contains(t, row, int64(113)) - assert.Contains(t, row[113], "x") - assert.Contains(t, row[113], "y") - }) - - t.Run("not auto-id, dynamic field not provided", func(t *testing.T) { - parser.fieldsName = []string{"FieldID"} - raw := []string{"1"} - row, err := parser.verifyRow(raw) - assert.NoError(t, err) - assert.Contains(t, row, int64(106)) - assert.Contains(t, row, int64(113)) - assert.Contains(t, "{}", row[113]) - }) - - t.Run("not auto-id, invalid input dynamic field", func(t *testing.T) { - parser.fieldsName = []string{"FieldID", "FieldDynamic", "y"} - raw := []string{"1", "true", "true"} - _, err = parser.verifyRow(raw) - assert.Error(t, err) - }) - - schema.Fields[0].AutoID = true - err = collectionInfo.resetSchema(schema) - assert.NoError(t, err) - t.Run("no need to provide value for auto-id", func(t *testing.T) { - parser.fieldsName = []string{"FieldID", "FieldDynamic", "y"} - raw := []string{"1", `{"x": 8}`, "true"} - _, err := parser.verifyRow(raw) - assert.Error(t, err) - - parser.fieldsName = 
[]string{"FieldDynamic", "y"} - raw = []string{`{"x": 8}`, "true"} - row, err := parser.verifyRow(raw) - assert.NoError(t, err) - assert.Contains(t, row, int64(113)) - }) - - schema.Fields[1].IsDynamic = false - err = collectionInfo.resetSchema(schema) - assert.NoError(t, err) - t.Run("auto id, no dynamic field", func(t *testing.T) { - parser.fieldsName = []string{"FieldDynamic", "y"} - raw := []string{`{"x": 8}`, "true"} - _, err := parser.verifyRow(raw) - assert.Error(t, err) - - // miss FieldDynamic - parser.fieldsName = []string{} - raw = []string{} - _, err = parser.verifyRow(raw) - assert.Error(t, err) - }) -} diff --git a/internal/util/importutil/csv_reader.go b/internal/util/importutil/csv_reader.go deleted file mode 100644 index decd30c5c4..0000000000 --- a/internal/util/importutil/csv_reader.go +++ /dev/null @@ -1,472 +0,0 @@ -// Copied from go 1.20 as go 1.18 csv reader not implement inputOffset - -package importutil - -// Copyright 2011 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -// Package csv reads and writes comma-separated values (CSV) files. -// There are many kinds of CSV files; this package supports the format -// described in RFC 4180. -// -// A csv file contains zero or more records of one or more fields per record. -// Each record is separated by the newline character. The final record may -// optionally be followed by a newline character. -// -// field1,field2,field3 -// -// White space is considered part of a field. -// -// Carriage returns before newline characters are silently removed. -// -// Blank lines are ignored. A line with only whitespace characters (excluding -// the ending newline character) is not considered a blank line. -// -// Fields which start and stop with the quote character " are called -// quoted-fields. The beginning and ending quote are not part of the -// field. -// -// The source: -// -// normal string,"quoted-field" -// -// results in the fields -// -// {`normal string`, `quoted-field`} -// -// Within a quoted-field a quote character followed by a second quote -// character is considered a single quote. -// -// "the ""word"" is true","a ""quoted-field""" -// -// results in -// -// {`the "word" is true`, `a "quoted-field"`} -// -// Newlines and commas may be included in a quoted-field -// -// "Multi-line -// field","comma is ," -// -// results in -// -// {`Multi-line -// field`, `comma is ,`} - -import ( - "bufio" - "bytes" - "fmt" - "io" - "unicode" - "unicode/utf8" - - "github.com/cockroachdb/errors" - - "github.com/milvus-io/milvus/pkg/util/merr" -) - -// A ParseError is returned for parsing errors. -// Line numbers are 1-indexed and columns are 0-indexed. -type ParseError struct { - StartLine int // Line where the record starts - Line int // Line where the error occurred - Column int // Column (1-based byte index) where the error occurred - Err error // The actual error -} - -func (e *ParseError) Error() string { - if errors.Is(e.Err, ErrFieldCount) { - return fmt.Sprintf("record on line %d: %v", e.Line, e.Err) - } - if e.StartLine != e.Line { - return fmt.Sprintf("record on line %d; parse error on line %d, column %d: %v", e.StartLine, e.Line, e.Column, e.Err) - } - return fmt.Sprintf("parse error on line %d, column %d: %v", e.Line, e.Column, e.Err) -} - -func (e *ParseError) Unwrap() error { return e.Err } - -// These are the errors that can be returned in ParseError.Err. 
-var ( - ErrBareQuote = merr.WrapErrImportFailed("bare \" in non-quoted-field") - ErrQuote = merr.WrapErrImportFailed("extraneous or missing \" in quoted-field") - ErrFieldCount = merr.WrapErrImportFailed("wrong number of fields") - - // Deprecated: ErrTrailingComma is no longer used. - ErrTrailingComma = merr.WrapErrImportFailed("extra delimiter at end of line") -) - -var errInvalidDelim = merr.WrapErrImportFailed("csv: invalid field or comment delimiter") - -func validDelim(r rune) bool { - return r != 0 && r != '"' && r != '\r' && r != '\n' && utf8.ValidRune(r) && r != utf8.RuneError -} - -// A Reader reads records from a CSV-encoded file. -// -// As returned by NewReader, a Reader expects input conforming to RFC 4180. -// The exported fields can be changed to customize the details before the -// first call to Read or ReadAll. -// -// The Reader converts all \r\n sequences in its input to plain \n, -// including in multiline field values, so that the returned data does -// not depend on which line-ending convention an input file uses. -type Reader struct { - // Comma is the field delimiter. - // It is set to comma (',') by NewReader. - // Comma must be a valid rune and must not be \r, \n, - // or the Unicode replacement character (0xFFFD). - Comma rune - - // Comment, if not 0, is the comment character. Lines beginning with the - // Comment character without preceding whitespace are ignored. - // With leading whitespace the Comment character becomes part of the - // field, even if TrimLeadingSpace is true. - // Comment must be a valid rune and must not be \r, \n, - // or the Unicode replacement character (0xFFFD). - // It must also not be equal to Comma. - Comment rune - - // FieldsPerRecord is the number of expected fields per record. - // If FieldsPerRecord is positive, Read requires each record to - // have the given number of fields. If FieldsPerRecord is 0, Read sets it to - // the number of fields in the first record, so that future records must - // have the same field count. If FieldsPerRecord is negative, no check is - // made and records may have a variable number of fields. - FieldsPerRecord int - - // If LazyQuotes is true, a quote may appear in an unquoted field and a - // non-doubled quote may appear in a quoted field. - LazyQuotes bool - - // If TrimLeadingSpace is true, leading white space in a field is ignored. - // This is done even if the field delimiter, Comma, is white space. - TrimLeadingSpace bool - - // ReuseRecord controls whether calls to Read may return a slice sharing - // the backing array of the previous call's returned slice for performance. - // By default, each call to Read returns newly allocated memory owned by the caller. - ReuseRecord bool - - // Deprecated: TrailingComma is no longer used. - TrailingComma bool - - r *bufio.Reader - - // numLine is the current line being read in the CSV file. - numLine int - - // offset is the input stream byte offset of the current reader position. - offset int64 - - // rawBuffer is a line buffer only used by the readLine method. - rawBuffer []byte - - // recordBuffer holds the unescaped fields, one after another. - // The fields can be accessed by using the indexes in fieldIndexes. - // E.g., For the row `a,"b","c""d",e`, recordBuffer will contain `abc"de` - // and fieldIndexes will contain the indexes [1, 2, 5, 6]. - recordBuffer []byte - - // fieldIndexes is an index of fields inside recordBuffer. - // The i'th field ends at offset fieldIndexes[i] in recordBuffer. 
- fieldIndexes []int - - // fieldPositions is an index of field positions for the - // last record returned by Read. - fieldPositions []position - - // lastRecord is a record cache and only used when ReuseRecord == true. - lastRecord []string -} - -// NewReader returns a new Reader that reads from r. -func NewReader(r io.Reader) *Reader { - return &Reader{ - Comma: ',', - r: bufio.NewReader(r), - } -} - -// Read reads one record (a slice of fields) from r. -// If the record has an unexpected number of fields, -// Read returns the record along with the error ErrFieldCount. -// If the record contains a field that cannot be parsed, -// Read returns a partial record along with the parse error. -// The partial record contains all fields read before the error. -// If there is no data left to be read, Read returns nil, io.EOF. -// If ReuseRecord is true, the returned slice may be shared -// between multiple calls to Read. -func (r *Reader) Read() (record []string, err error) { - if r.ReuseRecord { - record, err = r.readRecord(r.lastRecord) - r.lastRecord = record - } else { - record, err = r.readRecord(nil) - } - return record, err -} - -// FieldPos returns the line and column corresponding to -// the start of the field with the given index in the slice most recently -// returned by Read. Numbering of lines and columns starts at 1; -// columns are counted in bytes, not runes. -// -// If this is called with an out-of-bounds index, it panics. -func (r *Reader) FieldPos(field int) (line, column int) { - if field < 0 || field >= len(r.fieldPositions) { - panic("out of range index passed to FieldPos") - } - p := &r.fieldPositions[field] - return p.line, p.col -} - -// InputOffset returns the input stream byte offset of the current reader -// position. The offset gives the location of the end of the most recently -// read row and the beginning of the next row. -func (r *Reader) InputOffset() int64 { - return r.offset -} - -// pos holds the position of a field in the current line. -type position struct { - line, col int -} - -// ReadAll reads all the remaining records from r. -// Each record is a slice of fields. -// A successful call returns err == nil, not err == io.EOF. Because ReadAll is -// defined to read until EOF, it does not treat end of file as an error to be -// reported. -func (r *Reader) ReadAll() (records [][]string, err error) { - for { - record, err := r.readRecord(nil) - if err == io.EOF { - return records, nil - } - if err != nil { - return nil, err - } - records = append(records, record) - } -} - -// readLine reads the next line (with the trailing endline). -// If EOF is hit without a trailing endline, it will be omitted. -// If some bytes were read, then the error is never io.EOF. -// The result is only valid until the next call to readLine. -func (r *Reader) readLine() ([]byte, error) { - line, err := r.r.ReadSlice('\n') - if errors.Is(err, bufio.ErrBufferFull) { - r.rawBuffer = append(r.rawBuffer[:0], line...) - for errors.Is(err, bufio.ErrBufferFull) { - line, err = r.r.ReadSlice('\n') - r.rawBuffer = append(r.rawBuffer, line...) - } - line = r.rawBuffer - } - readSize := len(line) - if readSize > 0 && err == io.EOF { - err = nil - // For backwards compatibility, drop trailing \r before EOF. - if line[readSize-1] == '\r' { - line = line[:readSize-1] - } - } - r.numLine++ - r.offset += int64(readSize) - // Normalize \r\n to \n on all input lines. 
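InputOffset is the reason this reader was vendored at all (the comment at the top of the file notes that Go 1.18's csv reader does not implement it). Assuming Go 1.19 or newer, the standard csv.Reader exposes the same method, and the offset-against-file-size progress reporting done by ParseRows can be sketched like this:

package main

import (
	"encoding/csv"
	"errors"
	"fmt"
	"io"
	"strings"
)

func main() {
	data := "a,b,c\n1,2,3\n4,5,6\n"
	fileSize := int64(len(data))

	r := csv.NewReader(strings.NewReader(data))
	for {
		record, err := r.Read()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			panic(err)
		}
		// InputOffset is the byte offset just past the row that was read,
		// so offset/fileSize gives a rough parsing-progress percentage.
		percent := r.InputOffset() * 100 / fileSize
		fmt.Printf("read %v, progress %d%%\n", record, percent)
	}
}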
- if n := len(line); n >= 2 && line[n-2] == '\r' && line[n-1] == '\n' { - line[n-2] = '\n' - line = line[:n-1] - } - return line, err -} - -// lengthNL reports the number of bytes for the trailing \n. -func lengthNL(b []byte) int { - if len(b) > 0 && b[len(b)-1] == '\n' { - return 1 - } - return 0 -} - -// nextRune returns the next rune in b or utf8.RuneError. -func nextRune(b []byte) rune { - r, _ := utf8.DecodeRune(b) - return r -} - -func (r *Reader) readRecord(dst []string) ([]string, error) { - if r.Comma == r.Comment || !validDelim(r.Comma) || (r.Comment != 0 && !validDelim(r.Comment)) { - return nil, errInvalidDelim - } - - // Read line (automatically skipping past empty lines and any comments). - var line []byte - var errRead error - for errRead == nil { - line, errRead = r.readLine() - if r.Comment != 0 && nextRune(line) == r.Comment { - line = nil - continue // Skip comment lines - } - if errRead == nil && len(line) == lengthNL(line) { - line = nil - continue // Skip empty lines - } - break - } - if errRead == io.EOF { - return nil, errRead - } - - // Parse each field in the record. - var err error - const quoteLen = len(`"`) - commaLen := utf8.RuneLen(r.Comma) - recLine := r.numLine // Starting line for record - r.recordBuffer = r.recordBuffer[:0] - r.fieldIndexes = r.fieldIndexes[:0] - r.fieldPositions = r.fieldPositions[:0] - pos := position{line: r.numLine, col: 1} -parseField: - for { - if r.TrimLeadingSpace { - i := bytes.IndexFunc(line, func(r rune) bool { - return !unicode.IsSpace(r) - }) - if i < 0 { - i = len(line) - pos.col -= lengthNL(line) - } - line = line[i:] - pos.col += i - } - if len(line) == 0 || line[0] != '"' { - // Non-quoted string field - i := bytes.IndexRune(line, r.Comma) - field := line - if i >= 0 { - field = field[:i] - } else { - field = field[:len(field)-lengthNL(field)] - } - // Check to make sure a quote does not appear in field. - if !r.LazyQuotes { - if j := bytes.IndexByte(field, '"'); j >= 0 { - col := pos.col + j - err = &ParseError{StartLine: recLine, Line: r.numLine, Column: col, Err: ErrBareQuote} - break parseField - } - } - r.recordBuffer = append(r.recordBuffer, field...) - r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer)) - r.fieldPositions = append(r.fieldPositions, pos) - if i >= 0 { - line = line[i+commaLen:] - pos.col += i + commaLen - continue parseField - } - break parseField - } else { - // Quoted string field - fieldPos := pos - line = line[quoteLen:] - pos.col += quoteLen - for { - i := bytes.IndexByte(line, '"') - if i >= 0 { - // Hit next quote. - r.recordBuffer = append(r.recordBuffer, line[:i]...) - line = line[i+quoteLen:] - pos.col += i + quoteLen - switch rn := nextRune(line); { - case rn == '"': - // `""` sequence (append quote). - r.recordBuffer = append(r.recordBuffer, '"') - line = line[quoteLen:] - pos.col += quoteLen - case rn == r.Comma: - // `",` sequence (end of field). - line = line[commaLen:] - pos.col += commaLen - r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer)) - r.fieldPositions = append(r.fieldPositions, fieldPos) - continue parseField - case lengthNL(line) == len(line): - // `"\n` sequence (end of line). - r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer)) - r.fieldPositions = append(r.fieldPositions, fieldPos) - break parseField - case r.LazyQuotes: - // `"` sequence (bare quote). - r.recordBuffer = append(r.recordBuffer, '"') - default: - // `"*` sequence (invalid non-escaped quote). 
- err = &ParseError{StartLine: recLine, Line: r.numLine, Column: pos.col - quoteLen, Err: ErrQuote} - break parseField - } - } else if len(line) > 0 { - // Hit end of line (copy all data so far). - r.recordBuffer = append(r.recordBuffer, line...) - if errRead != nil { - break parseField - } - pos.col += len(line) - line, errRead = r.readLine() - if len(line) > 0 { - pos.line++ - pos.col = 1 - } - if errRead == io.EOF { - errRead = nil - } - } else { - // Abrupt end of file (EOF or error). - if !r.LazyQuotes && errRead == nil { - err = &ParseError{StartLine: recLine, Line: pos.line, Column: pos.col, Err: ErrQuote} - break parseField - } - r.fieldIndexes = append(r.fieldIndexes, len(r.recordBuffer)) - r.fieldPositions = append(r.fieldPositions, fieldPos) - break parseField - } - } - } - } - if err == nil { - err = errRead - } - - // Create a single string and create slices out of it. - // This pins the memory of the fields together, but allocates once. - str := string(r.recordBuffer) // Convert to string once to batch allocations - dst = dst[:0] - if cap(dst) < len(r.fieldIndexes) { - dst = make([]string, len(r.fieldIndexes)) - } - dst = dst[:len(r.fieldIndexes)] - var preIdx int - for i, idx := range r.fieldIndexes { - dst[i] = str[preIdx:idx] - preIdx = idx - } - - // Check or update the expected fields per record. - if r.FieldsPerRecord > 0 { - if len(dst) != r.FieldsPerRecord && err == nil { - err = &ParseError{ - StartLine: recLine, - Line: recLine, - Column: 1, - Err: ErrFieldCount, - } - } - } else if r.FieldsPerRecord == 0 { - r.FieldsPerRecord = len(dst) - } - return dst, err -} diff --git a/internal/util/importutil/import_wrapper.go b/internal/util/importutil/import_wrapper.go index e2e970b38e..2c5f410169 100644 --- a/internal/util/importutil/import_wrapper.go +++ b/internal/util/importutil/import_wrapper.go @@ -39,7 +39,6 @@ import ( const ( JSONFileExt = ".json" NumpyFileExt = ".npy" - CSVFileExt = ".csv" // parsers read JSON/Numpy/CSV files buffer by buffer, this limitation is to define the buffer size. 
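The tail of readRecord above is the allocation trick the reader relies on: unescaped field bytes are accumulated in recordBuffer, fieldIndexes records where each field ends, and the whole record is cut out of a single string conversion, so one Read costs one string allocation. A standalone sketch of that technique, reusing the example values from the struct comments:

package main

import "fmt"

func main() {
	// recordBuffer holds the unescaped fields back to back;
	// fieldIndexes[i] is the end offset of field i inside it.
	// For the row `a,"b","c""d",e` that is `abc"de` and [1, 2, 5, 6].
	recordBuffer := []byte(`abc"de`)
	fieldIndexes := []int{1, 2, 5, 6}

	// One string conversion (one allocation) for the whole record,
	// then each field is a substring sharing that memory.
	str := string(recordBuffer)
	fields := make([]string, len(fieldIndexes))
	preIdx := 0
	for i, idx := range fieldIndexes {
		fields[i] = str[preIdx:idx]
		preIdx = idx
	}
	fmt.Printf("%q\n", fields) // ["a" "b" "c\"d" "e"]
}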
ReadBufferSize = 16 * 1024 * 1024 // 16MB @@ -189,20 +188,20 @@ func (p *ImportWrapper) fileValidation(filePaths []string) (bool, error) { name, fileType := GetFileNameAndExt(filePath) // only allow json file, numpy file and csv file - if fileType != JSONFileExt && fileType != NumpyFileExt && fileType != CSVFileExt { + if fileType != JSONFileExt && fileType != NumpyFileExt { log.Warn("import wrapper: unsupported file type", zap.String("filePath", filePath)) return false, merr.WrapErrImportFailed(fmt.Sprintf("unsupported file type: '%s'", filePath)) } // we use the first file to determine row-based or column-based - if i == 0 && (fileType == JSONFileExt || fileType == CSVFileExt) { + if i == 0 && fileType == JSONFileExt { rowBased = true } // check file type // row-based only support json and csv type, column-based only support numpy type if rowBased { - if fileType != JSONFileExt && fileType != CSVFileExt { + if fileType != JSONFileExt { log.Warn("import wrapper: unsupported file type for row-based mode", zap.String("filePath", filePath)) return rowBased, merr.WrapErrImportFailed(fmt.Sprintf("unsupported file type for row-based mode: '%s'", filePath)) } @@ -280,12 +279,6 @@ func (p *ImportWrapper) Import(filePaths []string, options ImportOptions) error log.Warn("import wrapper: failed to parse row-based json file", zap.Error(err), zap.String("filePath", filePath)) return err } - } else if fileType == CSVFileExt { - err = p.parseRowBasedCSV(filePath, options.OnlyValidate) - if err != nil { - log.Warn("import wrapper: failed to parse row-based csv file", zap.Error(err), zap.String("filePath", filePath)) - return err - } } // no need to check else, since the fileValidation() already do this // trigger gc after each file finished @@ -467,54 +460,6 @@ func (p *ImportWrapper) parseRowBasedJSON(filePath string, onlyValidate bool) er return nil } -func (p *ImportWrapper) parseRowBasedCSV(filePath string, onlyValidate bool) error { - tr := timerecord.NewTimeRecorder("csv row-based parser: " + filePath) - - file, err := p.chunkManager.Reader(p.ctx, filePath) - if err != nil { - return err - } - defer file.Close() - size, err := p.chunkManager.Size(p.ctx, filePath) - if err != nil { - return err - } - // csv parser - reader := bufio.NewReader(file) - parser, err := NewCSVParser(p.ctx, p.collectionInfo, p.updateProgressPercent) - if err != nil { - return err - } - - // if only validate, we input a empty flushFunc so that the consumer do nothing but only validation. - var flushFunc ImportFlushFunc - if onlyValidate { - flushFunc = func(fields BlockData, shardID int, partitionID int64) error { - return nil - } - } else { - flushFunc = func(fields BlockData, shardID int, partitionID int64) error { - filePaths := []string{filePath} - printFieldsDataInfo(fields, "import wrapper: prepare to flush binlogs", filePaths) - return p.flushFunc(fields, shardID, partitionID) - } - } - - consumer, err := NewCSVRowConsumer(p.ctx, p.collectionInfo, p.rowIDAllocator, p.binlogSize, flushFunc) - if err != nil { - return err - } - - err = parser.ParseRows(&IOReader{r: reader, fileSize: size}, consumer) - if err != nil { - return err - } - p.importResult.AutoIds = append(p.importResult.AutoIds, consumer.IDRange()...) 
- - tr.Elapse("parsed") - return nil -} - // flushFunc is the callback function for parsers generate segment and save binlog files func (p *ImportWrapper) flushFunc(fields BlockData, shardID int, partitionID int64) error { logFields := []zap.Field{ diff --git a/internal/util/importutil/import_wrapper_test.go b/internal/util/importutil/import_wrapper_test.go index 7ec785e8e3..ec0cc15fef 100644 --- a/internal/util/importutil/import_wrapper_test.go +++ b/internal/util/importutil/import_wrapper_test.go @@ -326,93 +326,6 @@ func Test_ImportWrapperRowBased(t *testing.T) { }) } -func Test_ImportWrapperRowBased_CSV(t *testing.T) { - err := os.MkdirAll(TempFilesPath, os.ModePerm) - assert.NoError(t, err) - defer os.RemoveAll(TempFilesPath) - paramtable.Init() - - // NewDefaultFactory() use "/tmp/milvus" as default root path, and cannot specify root path - // NewChunkManagerFactory() can specify the root path - f := storage.NewChunkManagerFactory("local", storage.RootPath(TempFilesPath)) - ctx := context.Background() - cm, err := f.NewPersistentStorageChunkManager(ctx) - assert.NoError(t, err) - defer cm.RemoveWithPrefix(ctx, cm.RootPath()) - - idAllocator := newIDAllocator(ctx, t, nil) - content := []byte( - `FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector,FieldArray - true,10,101,1001,10001,3.14,1.56,No.0,"{""x"": 0}","[200,0]","[0.1,0.2,0.3,0.4]","[1,2,3,4]" - false,11,102,1002,10002,3.15,1.57,No.1,"{""x"": 1}","[201,0]","[0.1,0.2,0.3,0.4]","[5,6,7,8]" - true,12,103,1003,10003,3.16,1.58,No.2,"{""x"": 2}","[202,0]","[0.1,0.2,0.3,0.4]","[9,10,11,12]"`) - - filePath := TempFilesPath + "rows_1.csv" - err = cm.Write(ctx, filePath, content) - assert.NoError(t, err) - rowCounter := &rowCounterTest{} - assignSegmentFunc, flushFunc, saveSegmentFunc := createMockCallbackFunctions(t, rowCounter) - importResult := &rootcoordpb.ImportResult{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - }, - TaskId: 1, - DatanodeId: 1, - State: commonpb.ImportState_ImportStarted, - Segments: make([]int64, 0), - AutoIds: make([]int64, 0), - RowCount: 0, - } - - reportFunc := func(res *rootcoordpb.ImportResult) error { - return nil - } - collectionInfo, err := NewCollectionInfo(sampleSchema(), 2, []int64{1}) - assert.NoError(t, err) - - t.Run("success case", func(t *testing.T) { - wrapper := NewImportWrapper(ctx, collectionInfo, 1, ReadBufferSize, idAllocator, cm, importResult, reportFunc) - wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc) - files := make([]string, 0) - files = append(files, filePath) - err = wrapper.Import(files, ImportOptions{OnlyValidate: true}) - assert.NoError(t, err) - assert.Equal(t, 0, rowCounter.rowCount) - - err = wrapper.Import(files, DefaultImportOptions()) - assert.NoError(t, err) - assert.Equal(t, 3, rowCounter.rowCount) - assert.Equal(t, commonpb.ImportState_ImportPersisted, importResult.State) - }) - - t.Run("parse error", func(t *testing.T) { - content := []byte( - `FieldBool,FieldInt8,FieldInt16,FieldInt32,FieldInt64,FieldFloat,FieldDouble,FieldString,FieldJSON,FieldBinaryVector,FieldFloatVector - true,false,103,1003,10003,3.16,1.58,No.2,"{""x"": 2}","[202,0]","[0.1,0.2,0.3,0.4]"`) - - filePath = TempFilesPath + "rows_2.csv" - err = cm.Write(ctx, filePath, content) - assert.NoError(t, err) - - importResult.State = commonpb.ImportState_ImportStarted - wrapper := NewImportWrapper(ctx, collectionInfo, 1, ReadBufferSize, idAllocator, cm, importResult, reportFunc) 
- wrapper.SetCallbackFunctions(assignSegmentFunc, flushFunc, saveSegmentFunc) - files := make([]string, 0) - files = append(files, filePath) - err = wrapper.Import(files, ImportOptions{OnlyValidate: true}) - assert.Error(t, err) - assert.NotEqual(t, commonpb.ImportState_ImportPersisted, importResult.State) - }) - - t.Run("file doesn't exist", func(t *testing.T) { - files := make([]string, 0) - files = append(files, "/dummy/dummy.csv") - wrapper := NewImportWrapper(ctx, collectionInfo, 1, ReadBufferSize, idAllocator, cm, importResult, reportFunc) - err = wrapper.Import(files, ImportOptions{OnlyValidate: true}) - assert.Error(t, err) - }) -} - func Test_ImportWrapperColumnBased_numpy(t *testing.T) { err := os.MkdirAll(TempFilesPath, os.ModePerm) assert.NoError(t, err)
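With the CSV path removed, the file checks in import_wrapper.go's fileValidation reduce to: .json files mean a row-based import, .npy files mean a column-based import, mixing the two fails, and .csv is rejected outright. A standalone sketch of that rule (classifyImportFiles is a hypothetical helper for illustration, not the real fileValidation):

package main

import (
	"fmt"
	"path"
	"strings"
)

// classifyImportFiles mirrors the rule left after this patch:
// .json files select a row-based import, .npy files select a
// column-based import, mixed file types are rejected, and .csv
// is no longer an accepted extension.
func classifyImportFiles(files []string) (rowBased bool, err error) {
	for i, f := range files {
		ext := strings.ToLower(path.Ext(f))
		switch ext {
		case ".json":
			if i > 0 && !rowBased {
				return false, fmt.Errorf("mixed file types: %q", f)
			}
			rowBased = true
		case ".npy":
			if rowBased {
				return true, fmt.Errorf("mixed file types: %q", f)
			}
		default:
			return rowBased, fmt.Errorf("unsupported file type: %q", f)
		}
	}
	return rowBased, nil
}

func main() {
	fmt.Println(classifyImportFiles([]string{"rows_1.json"}))       // true <nil>
	fmt.Println(classifyImportFiles([]string{"vec.npy", "id.npy"})) // false <nil>
	fmt.Println(classifyImportFiles([]string{"rows_1.csv"}))        // rejected with an error
}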