diff --git a/internal/util/importutilv2/csv/row_parser_test.go b/internal/util/importutilv2/csv/row_parser_test.go
index 6f0a4948f1..47579410a7 100644
--- a/internal/util/importutilv2/csv/row_parser_test.go
+++ b/internal/util/importutilv2/csv/row_parser_test.go
@@ -27,6 +27,7 @@ import (
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/internal/json"
 	"github.com/milvus-io/milvus/pkg/v2/common"
+	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
 )
 
 func TestNewRowParser_Invalid(t *testing.T) {
@@ -78,6 +79,100 @@ func TestNewRowParser_Invalid(t *testing.T) {
 	}
 }
 
+// TestRowParser_Parse_SparseVector checks how CSV rows holding a JSON-encoded
+// sparse float vector are parsed, including the empty vector "{}".
+func TestRowParser_Parse_SparseVector(t *testing.T) {
+	schema := &schemapb.CollectionSchema{
+		Fields: []*schemapb.FieldSchema{
+			{
+				FieldID:      1,
+				Name:         "id",
+				IsPrimaryKey: true,
+				DataType:     schemapb.DataType_Int64,
+			},
+			{
+				FieldID:    2,
+				Name:       "sparse_vector",
+				DataType:   schemapb.DataType_SparseFloatVector,
+				TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "128"}},
+			},
+		},
+	}
+
+	tests := []struct {
+		name       string
+		header     []string
+		row        []string
+		wantMaxIdx uint32
+		wantErr    bool
+	}{
+		{
+			name:       "empty sparse vector",
+			header:     []string{"id", "sparse_vector"},
+			row:        []string{"1", "{}"},
+			wantMaxIdx: 0,
+			wantErr:    false,
+		},
+		{
+			name:       "key-value format",
+			header:     []string{"id", "sparse_vector"},
+			row:        []string{"1", "{\"5\":3.14}"},
+			wantMaxIdx: 6, // max index 5 + 1
+			wantErr:    false,
+		},
+		{
+			name:       "multiple key-value pairs",
+			header:     []string{"id", "sparse_vector"},
+			row:        []string{"1", "{\"1\":0.5,\"10\":1.5,\"100\":2.5}"},
+			wantMaxIdx: 101, // max index 100 + 1
+			wantErr:    false,
+		},
+		{
+			name:    "invalid format",
+			header:  []string{"id", "sparse_vector"},
+			row:     []string{"1", "{275574541:1.5383775}"},
+			wantErr: true,
+		},
+	}
+
+	nullkey := ""
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			parser, err := NewRowParser(schema, tt.header, nullkey)
+			if !assert.NoError(t, err) {
+				return // without a parser the Parse call below cannot run
+			}
+
+			row, err := parser.Parse(tt.row)
+
+			if tt.wantErr {
+				assert.Error(t, err)
+				return
+			}
+
+			assert.NoError(t, err)
+			assert.Contains(t, row, int64(2)) // sparse_vector field ID
+
+			sparseVec, ok := row[int64(2)].([]byte)
+			if !assert.True(t, ok, "sparse vector should be parsed into []byte") {
+				return
+			}
+
+			if tt.wantMaxIdx > 0 {
+				elemCount := len(sparseVec) / 8
+				assert.Greater(t, elemCount, 0)
+
+				// Check the last index matches our expectation
+				lastIdx := typeutil.SparseFloatRowIndexAt(sparseVec, elemCount-1)
+				assert.Equal(t, tt.wantMaxIdx-1, lastIdx)
+			} else {
+				assert.Empty(t, sparseVec)
+			}
+		})
+	}
+}
+
 func TestRowParser_Parse_Valid(t *testing.T) {
 	schema := &schemapb.CollectionSchema{
 		Fields: []*schemapb.FieldSchema{
diff --git a/internal/util/importutilv2/parquet/field_reader.go b/internal/util/importutilv2/parquet/field_reader.go
index adb6457f7b..cf243d62b9 100644
--- a/internal/util/importutilv2/parquet/field_reader.go
+++ b/internal/util/importutilv2/parquet/field_reader.go
@@ -668,6 +668,26 @@ func ReadBinaryData(pcr *FieldReader, count int64) (any, error) {
 	return data, nil
 }
 
+// parseSparseFloatRowVector parses one JSON-encoded sparse float row and returns
+// the row bytes produced by typeutil.CreateSparseFloatRowFromJSON together with
+// the index of the row's last element plus one (0 for a row without elements).
+// A string that cannot be parsed yields an import-failed error.
+func parseSparseFloatRowVector(str string) ([]byte, uint32, error) {
+	rowVec, err := typeutil.CreateSparseFloatRowFromJSON([]byte(str))
+	if err != nil {
+		return nil, 0, merr.WrapErrImportFailed(fmt.Sprintf("Invalid JSON string for SparseFloatVector: '%s', err = %v", str, err))
+	}
+	elemCount := len(rowVec) / 8
+	maxIdx := uint32(0)
+
+	// Only a non-empty row has a last element whose index can be read.
+	if elemCount > 0 {
+		maxIdx = typeutil.SparseFloatRowIndexAt(rowVec, elemCount-1) + 1
+	}
+
+	return rowVec, maxIdx, nil
+}
+
 func ReadSparseFloatVectorData(pcr *FieldReader, count int64) (any, error) {
 	data, err := ReadStringData(pcr, count)
 	if err != nil {
@@ -676,20 +696,22 @@ func ReadSparseFloatVectorData(pcr *FieldReader, count int64) (any, error) {
 	if data == nil {
 		return nil, nil
 	}
+
 	byteArr := make([][]byte, 0, count)
 	maxDim := uint32(0)
+
 	for _, str := range data.([]string) {
-		rowVec, err := typeutil.CreateSparseFloatRowFromJSON([]byte(str))
+		rowVec, rowMaxIdx, err := parseSparseFloatRowVector(str)
 		if err != nil {
-			return nil, merr.WrapErrImportFailed(fmt.Sprintf("Invalid JSON string for SparseFloatVector: '%s', err = %v", str, err))
+			return nil, err
 		}
+
 		byteArr = append(byteArr, rowVec)
-		elemCount := len(rowVec) / 8
-		maxIdx := typeutil.SparseFloatRowIndexAt(rowVec, elemCount-1)
-		if maxIdx+1 > maxDim {
-			maxDim = maxIdx + 1
+		if rowMaxIdx > maxDim {
+			maxDim = rowMaxIdx
 		}
 	}
+
 	return &storage.SparseFloatVectorFieldData{
 		SparseFloatArray: schemapb.SparseFloatArray{
 			Dim:      int64(maxDim),
diff --git a/internal/util/importutilv2/parquet/file_reader_test.go b/internal/util/importutilv2/parquet/file_reader_test.go
new file mode 100644
index 0000000000..5ea091c68e
--- /dev/null
+++ b/internal/util/importutilv2/parquet/file_reader_test.go
@@ -0,0 +1,93 @@
+package parquet
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+
+	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
+)
+
+// TestParseSparseFloatRowVector tests the parseSparseFloatRowVector function
+func TestParseSparseFloatRowVector(t *testing.T) {
+	tests := []struct {
+		name       string
+		input      string
+		wantMaxIdx uint32
+		wantErrMsg string
+	}{
+		{
+			name:       "empty sparse vector",
+			input:      "{}",
+			wantMaxIdx: 0,
+		},
+		{
+			name:       "key-value format",
+			input:      "{\"275574541\":1.5383775}",
+			wantMaxIdx: 275574542, // max index 275574541 + 1
+		},
+		{
+			name:       "multiple key-value pairs",
+			input:      "{\"1\":0.5,\"10\":1.5,\"100\":2.5}",
+			wantMaxIdx: 101, // max index 100 + 1
+		},
+		{
+			name:       "invalid format - missing braces",
+			input:      "\"275574541\":1.5383775",
+			wantErrMsg: "Invalid JSON string for SparseFloatVector",
+		},
+		{
+			name:       "invalid JSON format",
+			input:      "{275574541:1.5383775}",
+			wantErrMsg: "Invalid JSON string for SparseFloatVector",
+		},
+		{
+			name:       "malformed JSON",
+			input:      "{\"key\": value}",
+			wantErrMsg: "Invalid JSON string for SparseFloatVector",
+		},
+		{
+			name:       "non-numeric index",
+			input:      "{\"abc\":1.5}",
+			wantErrMsg: "Invalid JSON string for SparseFloatVector",
+		},
+		{
+			name:       "non-numeric value",
+			input:      "{\"123\":\"abc\"}",
+			wantErrMsg: "Invalid JSON string for SparseFloatVector",
+		},
+		{
+			name:       "negative index",
+			input:      "{\"-1\":1.5}",
+			wantErrMsg: "Invalid JSON string for SparseFloatVector",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			rowVec, maxIdx, err := parseSparseFloatRowVector(tt.input)
+
+			if tt.wantErrMsg != "" {
+				if assert.Error(t, err) {
+					assert.Contains(t, err.Error(), tt.wantErrMsg)
+				}
+				return
+			}
+
+			assert.NoError(t, err)
+			assert.Equal(t, tt.wantMaxIdx, maxIdx)
+
+			// Verify the rowVec is properly formatted
+			if maxIdx > 0 {
+				elemCount := len(rowVec) / 8
+				assert.Greater(t, elemCount, 0)
+
+				// Check the last index matches our expectation
+				lastIdx := typeutil.SparseFloatRowIndexAt(rowVec, elemCount-1)
+				assert.Equal(t, tt.wantMaxIdx-1, lastIdx)
+			} else {
+				assert.Empty(t, rowVec)
+			}
+		})
+	}
+}