mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
fix: empty sparse row in importer (#40585)
Fix #40584: the Parquet bulk writer cannot finish a 0-dim (empty) sparse vector. Signed-off-by: xiaofanluan <xiaofan.luan@zilliz.com>
This commit is contained in:
parent
9f3bd55755
commit
fb48b3c7ac
@ -27,6 +27,7 @@ import (
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/json"
|
||||
"github.com/milvus-io/milvus/pkg/v2/common"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
func TestNewRowParser_Invalid(t *testing.T) {
|
||||
@ -78,6 +79,93 @@ func TestNewRowParser_Invalid(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRowParser_Parse_SparseVector(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{
|
||||
FieldID: 1,
|
||||
Name: "id",
|
||||
IsPrimaryKey: true,
|
||||
DataType: schemapb.DataType_Int64,
|
||||
},
|
||||
{
|
||||
FieldID: 2,
|
||||
Name: "sparse_vector",
|
||||
DataType: schemapb.DataType_SparseFloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: "128"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
header []string
|
||||
row []string
|
||||
wantMaxIdx uint32
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "empty sparse vector",
|
||||
header: []string{"id", "sparse_vector"},
|
||||
row: []string{"1", "{}"},
|
||||
wantMaxIdx: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "key-value format",
|
||||
header: []string{"id", "sparse_vector"},
|
||||
row: []string{"1", "{\"5\":3.14}"},
|
||||
wantMaxIdx: 6, // max index 5 + 1
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "multiple key-value pairs",
|
||||
header: []string{"id", "sparse_vector"},
|
||||
row: []string{"1", "{\"1\":0.5,\"10\":1.5,\"100\":2.5}"},
|
||||
wantMaxIdx: 101, // max index 100 + 1
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid format",
|
||||
header: []string{"id", "sparse_vector"},
|
||||
row: []string{"1", "{275574541:1.5383775}"},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
nullkey := ""
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parser, err := NewRowParser(schema, tt.header, nullkey)
|
||||
assert.NoError(t, err)
|
||||
|
||||
row, err := parser.Parse(tt.row)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, row, int64(2)) // sparse_vector field ID
|
||||
|
||||
sparseVec := row[int64(2)].([]byte)
|
||||
|
||||
if tt.wantMaxIdx > 0 {
|
||||
elemCount := len(sparseVec) / 8
|
||||
assert.Greater(t, elemCount, 0)
|
||||
|
||||
// Check the last index matches our expectation
|
||||
lastIdx := typeutil.SparseFloatRowIndexAt(sparseVec, elemCount-1)
|
||||
assert.Equal(t, tt.wantMaxIdx-1, lastIdx)
|
||||
} else {
|
||||
assert.Empty(t, sparseVec)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRowParser_Parse_Valid(t *testing.T) {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
|
||||
@ -668,6 +668,21 @@ func ReadBinaryData(pcr *FieldReader, count int64) (any, error) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func parseSparseFloatRowVector(str string) ([]byte, uint32, error) {
|
||||
rowVec, err := typeutil.CreateSparseFloatRowFromJSON([]byte(str))
|
||||
if err != nil {
|
||||
return nil, 0, merr.WrapErrImportFailed(fmt.Sprintf("Invalid JSON string for SparseFloatVector: '%s', err = %v", str, err))
|
||||
}
|
||||
elemCount := len(rowVec) / 8
|
||||
maxIdx := uint32(0)
|
||||
|
||||
if elemCount > 0 {
|
||||
maxIdx = typeutil.SparseFloatRowIndexAt(rowVec, elemCount-1) + 1
|
||||
}
|
||||
|
||||
return rowVec, maxIdx, nil
|
||||
}
|
||||
|
||||
func ReadSparseFloatVectorData(pcr *FieldReader, count int64) (any, error) {
|
||||
data, err := ReadStringData(pcr, count)
|
||||
if err != nil {
|
||||
@ -676,20 +691,22 @@ func ReadSparseFloatVectorData(pcr *FieldReader, count int64) (any, error) {
|
||||
if data == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
byteArr := make([][]byte, 0, count)
|
||||
maxDim := uint32(0)
|
||||
|
||||
for _, str := range data.([]string) {
|
||||
rowVec, err := typeutil.CreateSparseFloatRowFromJSON([]byte(str))
|
||||
rowVec, rowMaxIdx, err := parseSparseFloatRowVector(str)
|
||||
if err != nil {
|
||||
return nil, merr.WrapErrImportFailed(fmt.Sprintf("Invalid JSON string for SparseFloatVector: '%s', err = %v", str, err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
byteArr = append(byteArr, rowVec)
|
||||
elemCount := len(rowVec) / 8
|
||||
maxIdx := typeutil.SparseFloatRowIndexAt(rowVec, elemCount-1)
|
||||
if maxIdx+1 > maxDim {
|
||||
maxDim = maxIdx + 1
|
||||
if rowMaxIdx > maxDim {
|
||||
maxDim = rowMaxIdx
|
||||
}
|
||||
}
|
||||
|
||||
return &storage.SparseFloatVectorFieldData{
|
||||
SparseFloatArray: schemapb.SparseFloatArray{
|
||||
Dim: int64(maxDim),
|
||||
|
||||
92
internal/util/importutilv2/parquet/file_reader_test.go
Normal file
92
internal/util/importutilv2/parquet/file_reader_test.go
Normal file
@ -0,0 +1,92 @@
|
||||
package parquet
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
||||
)
|
||||
|
||||
// TestParseSparseFloatRowVector tests the parseSparseFloatRowVector function
|
||||
func TestParseSparseFloatRowVector(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantMaxIdx uint32
|
||||
wantErrMsg string
|
||||
}{
|
||||
{
|
||||
name: "empty sparse vector",
|
||||
input: "{}",
|
||||
wantMaxIdx: 0,
|
||||
},
|
||||
{
|
||||
name: "key-value format",
|
||||
input: "{\"275574541\":1.5383775}",
|
||||
wantMaxIdx: 275574542, // max index 275574541 + 1
|
||||
},
|
||||
{
|
||||
name: "multiple key-value pairs",
|
||||
input: "{\"1\":0.5,\"10\":1.5,\"100\":2.5}",
|
||||
wantMaxIdx: 101, // max index 100 + 1
|
||||
},
|
||||
{
|
||||
name: "invalid format - missing braces",
|
||||
input: "\"275574541\":1.5383775",
|
||||
wantErrMsg: "Invalid JSON string for SparseFloatVector",
|
||||
},
|
||||
{
|
||||
name: "invalid JSON format",
|
||||
input: "{275574541:1.5383775}",
|
||||
wantErrMsg: "Invalid JSON string for SparseFloatVector",
|
||||
},
|
||||
{
|
||||
name: "malformed JSON",
|
||||
input: "{\"key\": value}",
|
||||
wantErrMsg: "Invalid JSON string for SparseFloatVector",
|
||||
},
|
||||
{
|
||||
name: "non-numeric index",
|
||||
input: "{\"abc\":1.5}",
|
||||
wantErrMsg: "Invalid JSON string for SparseFloatVector",
|
||||
},
|
||||
{
|
||||
name: "non-numeric value",
|
||||
input: "{\"123\":\"abc\"}",
|
||||
wantErrMsg: "Invalid JSON string for SparseFloatVector",
|
||||
},
|
||||
{
|
||||
name: "negative index",
|
||||
input: "{\"-1\":1.5}",
|
||||
wantErrMsg: "Invalid JSON string for SparseFloatVector",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rowVec, maxIdx, err := parseSparseFloatRowVector(tt.input)
|
||||
|
||||
if tt.wantErrMsg != "" {
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.wantErrMsg)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.wantMaxIdx, maxIdx)
|
||||
|
||||
// Verify the rowVec is properly formatted
|
||||
if maxIdx > 0 {
|
||||
elemCount := len(rowVec) / 8
|
||||
assert.Greater(t, elemCount, 0)
|
||||
|
||||
// Check the last index matches our expectation
|
||||
lastIdx := typeutil.SparseFloatRowIndexAt(rowVec, elemCount-1)
|
||||
assert.Equal(t, tt.wantMaxIdx-1, lastIdx)
|
||||
} else {
|
||||
assert.Empty(t, rowVec)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user