fix: partial upsert with nullable vector fields (#46936)

issue: #46849
related: #45993

Nullable vector fields use compressed storage, where null rows store no actual vector data. This causes issues during partial upsert, when existing data is merged with upsert data.

Changes:
- Add nullableVectorMergeContext to track row-to-data index mappings
- Add buildNullableVectorIdxMap to build the index mapping from ValidData
- Add rebuildNullableVectorFieldData to merge vectors in compressed format
- Add getVectorDataAtIndex/appendVectorData helpers for all vector types
- Fix checkAligned to handle all-null nullable vector fields (Vectors == nil)
- Add comprehensive unit tests for the new functions
- Add e2e tests covering three upsert scenarios:
  - Upsert all rows to null vectors
  - Upsert null rows to valid vectors
  - Partial upsert (scalar only, vector preserved)

Signed-off-by: marcelo-cjl <marcelo.chen@zilliz.com>
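For context, a minimal sketch of the compressed layout this fix deals with (illustrative types and names, not Milvus APIs): ValidData holds one flag per row, while the flat data slice holds vectors for valid rows only, so row index and data index diverge as soon as any row is null.

package main

import "fmt"

// Illustrative only: compressed nullable vector column.
// ValidData has one entry per row; Data holds dim floats per VALID row
// only, so len(Data)/Dim equals the number of true entries in ValidData.
type compressedVecColumn struct {
	Dim       int
	ValidData []bool
	Data      []float32
}

// vectorAt resolves a row index to its vector, or nil for a null row.
func (c *compressedVecColumn) vectorAt(row int) []float32 {
	if !c.ValidData[row] {
		return nil
	}
	// Data index = number of valid rows before this one.
	dataIdx := 0
	for i := 0; i < row; i++ {
		if c.ValidData[i] {
			dataIdx++
		}
	}
	start := dataIdx * c.Dim
	return c.Data[start : start+c.Dim]
}

func main() {
	col := &compressedVecColumn{
		Dim:       2,
		ValidData: []bool{true, false, true}, // row 1 is null
		Data:      []float32{1, 2, 3, 4},     // only 2 vectors stored
	}
	fmt.Println(col.vectorAt(0)) // [1 2]
	fmt.Println(col.vectorAt(1)) // [] (nil: null row)
	fmt.Println(col.vectorAt(2)) // [3 4]
}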
@@ -365,9 +365,10 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
 	it.insertFieldData = typeutil.PrepareResultFieldData(existFieldData, int64(upsertIDSize))

 	if len(updateIdxInUpsert) > 0 {
-		// Note: For fields containing default values, default values need to be set according to valid data during insertion,
-		// but query result fields do not set valid data when returning default value fields,
-		// therefore valid data needs to be manually set to true
+		upsertFieldMap := lo.SliceToMap(it.upsertMsg.InsertMsg.GetFieldsData(), func(field *schemapb.FieldData) (int64, *schemapb.FieldData) {
+			return field.FieldId, field
+		})
+		nullableVectorContexts := make(map[int64]*nullableVectorMergeContext)
 		for _, fieldData := range existFieldData {
 			fieldSchema, err := it.schema.schemaHelper.GetFieldFromName(fieldData.GetFieldName())
 			if err != nil {
@@ -375,10 +376,29 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
 				return err
 			}

-			if fieldSchema.GetDefaultValue() != nil {
-				fieldData.ValidData = make([]bool, upsertIDSize)
-				for i := range fieldData.ValidData {
-					fieldData.ValidData[i] = true
+			if fieldSchema.GetNullable() && typeutil.IsVectorType(fieldSchema.GetDataType()) {
+				ctx := &nullableVectorMergeContext{
+					existIdxMap: buildNullableVectorIdxMap(fieldData.GetValidData()),
+					existField:  fieldData,
+					mergedValid: make([]bool, 0, len(updateIdxInUpsert)),
+				}
+
+				if upsertField, ok := upsertFieldMap[fieldData.FieldId]; ok {
+					ctx.upsertIdxMap = buildNullableVectorIdxMap(upsertField.GetValidData())
+					ctx.upsertField = upsertField
+					ctx.hasUpsertData = true
+				}
+
+				nullableVectorContexts[fieldData.FieldId] = ctx
+			} else if fieldSchema.GetDefaultValue() != nil {
+				// Note: For fields containing default values, default values need to be set according to valid data during insertion,
+				// but query result fields do not set valid data when returning default value fields,
+				// therefore valid data needs to be manually set to true
+				if len(fieldData.GetValidData()) == 0 {
+					fieldData.ValidData = make([]bool, upsertIDSize)
+					for i := range fieldData.ValidData {
+						fieldData.ValidData[i] = true
+					}
 				}
 			}
 		}
@@ -409,6 +429,28 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
 			log.Info("update field data failed", zap.Error(err))
 			return err
 		}
+
+		// Collect merged validity for nullable vector fields
+		for _, ctx := range nullableVectorContexts {
+			if ctx.hasUpsertData {
+				ctx.mergedValid = append(ctx.mergedValid, ctx.upsertField.GetValidData()[idx])
+			} else {
+				ctx.mergedValid = append(ctx.mergedValid, ctx.existField.GetValidData()[existIndex])
+			}
+		}
 	}
+
+	// Rebuild nullable vector fields
+	for i, fieldData := range it.insertFieldData {
+		ctx, ok := nullableVectorContexts[fieldData.FieldId]
+		if !ok {
+			continue
+		}
+
+		newFieldData := rebuildNullableVectorFieldData(ctx, updateIdxInUpsert, upsertIDs, existPKToIndex)
+		if newFieldData != nil {
+			it.insertFieldData[i] = newFieldData
+		}
+	}
 }

@@ -482,6 +524,196 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
 	return nil
 }

+func buildNullableVectorIdxMap(validData []bool) []int64 {
+	if len(validData) == 0 {
+		return nil
+	}
+	idxMap := make([]int64, len(validData))
+	dataIdx := int64(0)
+	for i, valid := range validData {
+		if valid {
+			idxMap[i] = dataIdx
+			dataIdx++
+		} else {
+			idxMap[i] = -1
+		}
+	}
+	return idxMap
+}
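A quick illustration of the mapping this produces; buildNullableVectorIdxMap is reproduced from the diff above so the example runs standalone:

package main

import "fmt"

// Copy of buildNullableVectorIdxMap from the diff above.
func buildNullableVectorIdxMap(validData []bool) []int64 {
	if len(validData) == 0 {
		return nil
	}
	idxMap := make([]int64, len(validData))
	dataIdx := int64(0)
	for i, valid := range validData {
		if valid {
			idxMap[i] = dataIdx
			dataIdx++
		} else {
			idxMap[i] = -1
		}
	}
	return idxMap
}

func main() {
	// Valid rows map to consecutive positions in the compressed data;
	// null rows map to -1 (nothing stored for them).
	fmt.Println(buildNullableVectorIdxMap([]bool{true, false, true, false, true}))
	// Output: [0 -1 1 -1 2]
}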
+
+type nullableVectorMergeContext struct {
+	existIdxMap   []int64 // row index -> data index for exist data
+	upsertIdxMap  []int64 // row index -> data index for upsert data
+	mergedValid   []bool  // merged validity for the output
+	existField    *schemapb.FieldData
+	upsertField   *schemapb.FieldData
+	hasUpsertData bool // whether the upsert request contains this field
+}
+
+func rebuildNullableVectorFieldData(
+	ctx *nullableVectorMergeContext,
+	updateIdxInUpsert []int,
+	upsertIDs *schemapb.IDs,
+	existPKToIndex map[interface{}]int,
+) *schemapb.FieldData {
+	var sourceField *schemapb.FieldData
+	var sourceIdxMap []int64
+
+	if ctx.hasUpsertData {
+		sourceField = ctx.upsertField
+		sourceIdxMap = ctx.upsertIdxMap
+	} else {
+		sourceField = ctx.existField
+		sourceIdxMap = ctx.existIdxMap
+	}
+
+	if sourceField == nil {
+		return nil
+	}
+
+	newFieldData := prepareNullableVectorFieldData(sourceField, int64(len(updateIdxInUpsert)))
+	newFieldData.ValidData = ctx.mergedValid
+
+	// Append vectors from the source using the correct indices
+	for rowIdx, idx := range updateIdxInUpsert {
+		var sourceRowIdx int
+		if ctx.hasUpsertData {
+			sourceRowIdx = idx // upsertMsg uses updateIdxInUpsert
+		} else {
+			// existFieldData uses existIndex
+			oldPK := typeutil.GetPK(upsertIDs, int64(idx))
+			sourceRowIdx = existPKToIndex[oldPK]
+		}
+
+		// Check whether this row is valid (has vector data)
+		if rowIdx < len(ctx.mergedValid) && ctx.mergedValid[rowIdx] {
+			dataIdx := int64(sourceRowIdx)
+			if len(sourceIdxMap) > 0 && sourceRowIdx < len(sourceIdxMap) {
+				dataIdx = sourceIdxMap[sourceRowIdx]
+			}
+			if dataIdx >= 0 {
+				appendSingleVector(newFieldData, sourceField, dataIdx)
+			}
+		}
+	}
+
+	return newFieldData
+}
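To make the index juggling concrete, here is a hand-traced example (values invented for illustration) of a partial upsert whose request omits the vector field, so the merge pulls from the existing data:

// Suppose one row (PK=20) is being updated, the upsert request omits the
// vector field (hasUpsertData == false), and the queried existing data is:
//
//   existField.ValidData = [true, false, true]  // 3 fetched rows
//   existIdxMap          = [0, -1, 1]           // from buildNullableVectorIdxMap
//   existPKToIndex       = {10: 0, 20: 1, 30: 2}
//
// For the updated row, rebuildNullableVectorFieldData resolves:
//   sourceRowIdx = existPKToIndex[20] = 1   // row position in exist data
//   dataIdx      = existIdxMap[1]     = -1  // row is null: nothing stored
// so no vector is appended and mergedValid keeps the row as null.
// Had PK=30 been updated instead, dataIdx would be existIdxMap[2] = 1 and
// the second stored vector would be copied into the rebuilt field.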
+
+func prepareNullableVectorFieldData(sample *schemapb.FieldData, capacity int64) *schemapb.FieldData {
+	fd := &schemapb.FieldData{
+		Type:      sample.Type,
+		FieldName: sample.FieldName,
+		FieldId:   sample.FieldId,
+		IsDynamic: sample.IsDynamic,
+		ValidData: make([]bool, 0, capacity),
+	}
+
+	vectorField := sample.GetVectors()
+	if vectorField == nil {
+		return fd
+	}
+	dim := vectorField.GetDim()
+	vectors := &schemapb.FieldData_Vectors{
+		Vectors: &schemapb.VectorField{
+			Dim: dim,
+		},
+	}
+
+	switch vectorField.Data.(type) {
+	case *schemapb.VectorField_FloatVector:
+		vectors.Vectors.Data = &schemapb.VectorField_FloatVector{
+			FloatVector: &schemapb.FloatArray{
+				Data: make([]float32, 0, dim*capacity),
+			},
+		}
+	case *schemapb.VectorField_Float16Vector:
+		vectors.Vectors.Data = &schemapb.VectorField_Float16Vector{
+			Float16Vector: make([]byte, 0, dim*2*capacity),
+		}
+	case *schemapb.VectorField_Bfloat16Vector:
+		vectors.Vectors.Data = &schemapb.VectorField_Bfloat16Vector{
+			Bfloat16Vector: make([]byte, 0, dim*2*capacity),
+		}
+	case *schemapb.VectorField_BinaryVector:
+		vectors.Vectors.Data = &schemapb.VectorField_BinaryVector{
+			BinaryVector: make([]byte, 0, dim/8*capacity),
+		}
+	case *schemapb.VectorField_Int8Vector:
+		vectors.Vectors.Data = &schemapb.VectorField_Int8Vector{
+			Int8Vector: make([]byte, 0, dim*capacity),
+		}
+	case *schemapb.VectorField_SparseFloatVector:
+		vectors.Vectors.Data = &schemapb.VectorField_SparseFloatVector{
+			SparseFloatVector: &schemapb.SparseFloatArray{
+				Contents: make([][]byte, 0, capacity),
+				Dim:      vectorField.GetSparseFloatVector().GetDim(),
+			},
+		}
+	}
+	fd.Field = vectors
+	return fd
+}
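The pre-allocation sizes follow directly from each type's per-row storage cost; a short note on the arithmetic the switch above encodes:

// Elements (or bytes) stored per valid row, by vector type:
//
//   FloatVector:    dim float32 elements     -> cap = dim * capacity
//   Float16Vector:  2 bytes per dimension    -> cap = dim * 2 * capacity
//   BFloat16Vector: 2 bytes per dimension    -> cap = dim * 2 * capacity
//   BinaryVector:   1 bit per dimension      -> cap = dim / 8 * capacity
//   Int8Vector:     1 byte per dimension     -> cap = dim * capacity
//   SparseFloat:    one opaque []byte per row -> cap = capacity
//
// e.g. dim = 128, capacity = 50 rows:
//   float32: 128*50 = 6400 elements; fp16/bf16: 128*2*50 = 12800 bytes;
//   binary:  128/8*50 = 800 bytes;   int8: 128*50 = 6400 bytes.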
+
+func appendSingleVector(target *schemapb.FieldData, source *schemapb.FieldData, dataIdx int64) {
+	targetVectors := target.GetVectors()
+	sourceVectors := source.GetVectors()
+	if targetVectors == nil || sourceVectors == nil {
+		return
+	}
+	dim := sourceVectors.GetDim()
+
+	switch sv := sourceVectors.Data.(type) {
+	case *schemapb.VectorField_FloatVector:
+		tv := targetVectors.Data.(*schemapb.VectorField_FloatVector)
+		start := dataIdx * dim
+		end := start + dim
+		if end <= int64(len(sv.FloatVector.Data)) {
+			tv.FloatVector.Data = append(tv.FloatVector.Data, sv.FloatVector.Data[start:end]...)
+		}
+	case *schemapb.VectorField_Float16Vector:
+		tv := targetVectors.Data.(*schemapb.VectorField_Float16Vector)
+		unitSize := dim * 2
+		start := dataIdx * unitSize
+		end := start + unitSize
+		if end <= int64(len(sv.Float16Vector)) {
+			tv.Float16Vector = append(tv.Float16Vector, sv.Float16Vector[start:end]...)
+		}
+	case *schemapb.VectorField_Bfloat16Vector:
+		tv := targetVectors.Data.(*schemapb.VectorField_Bfloat16Vector)
+		unitSize := dim * 2
+		start := dataIdx * unitSize
+		end := start + unitSize
+		if end <= int64(len(sv.Bfloat16Vector)) {
+			tv.Bfloat16Vector = append(tv.Bfloat16Vector, sv.Bfloat16Vector[start:end]...)
+		}
+	case *schemapb.VectorField_BinaryVector:
+		tv := targetVectors.Data.(*schemapb.VectorField_BinaryVector)
+		unitSize := dim / 8
+		start := dataIdx * unitSize
+		end := start + unitSize
+		if end <= int64(len(sv.BinaryVector)) {
+			tv.BinaryVector = append(tv.BinaryVector, sv.BinaryVector[start:end]...)
+		}
+	case *schemapb.VectorField_Int8Vector:
+		tv := targetVectors.Data.(*schemapb.VectorField_Int8Vector)
+		start := dataIdx * dim
+		end := start + dim
+		if end <= int64(len(sv.Int8Vector)) {
+			tv.Int8Vector = append(tv.Int8Vector, sv.Int8Vector[start:end]...)
+		}
+	case *schemapb.VectorField_SparseFloatVector:
+		tv := targetVectors.Data.(*schemapb.VectorField_SparseFloatVector)
+		if dataIdx < int64(len(sv.SparseFloatVector.Contents)) {
+			tv.SparseFloatVector.Contents = append(tv.SparseFloatVector.Contents, sv.SparseFloatVector.Contents[dataIdx])
+			// Update dimension if necessary
+			if sv.SparseFloatVector.Dim > tv.SparseFloatVector.Dim {
+				tv.SparseFloatVector.Dim = sv.SparseFloatVector.Dim
+			}
+		}
+	}
+}
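Note the defensive bounds checks: an out-of-range dataIdx makes the append a silent no-op rather than a panic, and for sparse vectors the target's Dim is raised to the source's when larger, which suggests SparseFloatArray.Dim tracks the maximum dimension across rows. A minimal standalone sketch of the bounds-check pattern (plain slices instead of schemapb types):

package main

import "fmt"

// Same skip-on-overflow behavior as the float-vector arm above.
func appendVectorAt(dst, src []float32, dim, dataIdx int) []float32 {
	start := dataIdx * dim
	end := start + dim
	if end <= len(src) { // out-of-range index: silently skip
		dst = append(dst, src[start:end]...)
	}
	return dst
}

func main() {
	src := []float32{1, 2, 3, 4} // 2 vectors of dim 2
	var dst []float32
	dst = appendVectorAt(dst, src, 2, 1)
	fmt.Println(dst) // [3 4]
	dst = appendVectorAt(dst, src, 2, 5) // past the end: no-op
	fmt.Println(dst) // [3 4]
}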
+
 // ToCompressedFormatNullable converts the field data from full format nullable to compressed format nullable
 func ToCompressedFormatNullable(field *schemapb.FieldData) error {
 	if getValidNumber(field.GetValidData()) == len(field.GetValidData()) {
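The truncated context above is the pre-existing converter between the two nullable layouts. For orientation only, a minimal sketch of what full-to-compressed conversion means for a float vector field (illustrative names, not the Milvus implementation):

package main

import "fmt"

// fullToCompressed drops rows marked invalid from a full-format buffer
// (one dim-sized slot per row, null rows zero-filled) and keeps only the
// valid rows' data, which is the compressed nullable layout.
func fullToCompressed(data []float32, valid []bool, dim int) []float32 {
	out := make([]float32, 0, len(data))
	for row, ok := range valid {
		if ok {
			out = append(out, data[row*dim:(row+1)*dim]...)
		}
	}
	return out
}

func main() {
	full := []float32{1, 2, 0, 0, 3, 4} // 3 rows of dim 2; row 1 is null
	fmt.Println(fullToCompressed(full, []bool{true, false, true}, 2))
	// Output: [1 2 3 4]
}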
@@ -1825,3 +1825,383 @@ func TestUpsertTask_queryPreExecute_EmptyDataArray(t *testing.T) {
 		})
 	})
 }
+
+func TestBuildNullableVectorIdxMap(t *testing.T) {
+	t.Run("empty validData", func(t *testing.T) {
+		result := buildNullableVectorIdxMap(nil)
+		assert.Nil(t, result)
+
+		result = buildNullableVectorIdxMap([]bool{})
+		assert.Nil(t, result)
+	})
+
+	t.Run("all valid", func(t *testing.T) {
+		validData := []bool{true, true, true, true}
+		result := buildNullableVectorIdxMap(validData)
+		assert.Equal(t, []int64{0, 1, 2, 3}, result)
+	})
+
+	t.Run("all null", func(t *testing.T) {
+		validData := []bool{false, false, false, false}
+		result := buildNullableVectorIdxMap(validData)
+		assert.Equal(t, []int64{-1, -1, -1, -1}, result)
+	})
+
+	t.Run("mixed valid and null", func(t *testing.T) {
+		validData := []bool{true, false, true, false, true}
+		result := buildNullableVectorIdxMap(validData)
+		// dataIdx: 0 for row 0, -1 for row 1, 1 for row 2, -1 for row 3, 2 for row 4
+		assert.Equal(t, []int64{0, -1, 1, -1, 2}, result)
+	})
+
+	t.Run("alternating pattern", func(t *testing.T) {
+		validData := []bool{false, true, false, true, false, true}
+		result := buildNullableVectorIdxMap(validData)
+		assert.Equal(t, []int64{-1, 0, -1, 1, -1, 2}, result)
+	})
+}
+
+func TestPrepareNullableVectorFieldData(t *testing.T) {
+	dim := int64(4)
+
+	t.Run("FloatVector", func(t *testing.T) {
+		sample := &schemapb.FieldData{
+			Type:      schemapb.DataType_FloatVector,
+			FieldName: "float_vec",
+			FieldId:   100,
+			Field: &schemapb.FieldData_Vectors{
+				Vectors: &schemapb.VectorField{
+					Dim:  dim,
+					Data: &schemapb.VectorField_FloatVector{FloatVector: &schemapb.FloatArray{Data: []float32{1, 2, 3, 4}}},
+				},
+			},
+		}
+		result := prepareNullableVectorFieldData(sample, 10)
+		assert.Equal(t, schemapb.DataType_FloatVector, result.Type)
+		assert.Equal(t, "float_vec", result.FieldName)
+		assert.Equal(t, int64(100), result.FieldId)
+		assert.Equal(t, dim, result.GetVectors().GetDim())
+		assert.Empty(t, result.GetVectors().GetFloatVector().Data)
+	})
+
+	t.Run("Float16Vector", func(t *testing.T) {
+		sample := &schemapb.FieldData{
+			Type:      schemapb.DataType_Float16Vector,
+			FieldName: "fp16_vec",
+			FieldId:   101,
+			Field: &schemapb.FieldData_Vectors{
+				Vectors: &schemapb.VectorField{
+					Dim:  dim,
+					Data: &schemapb.VectorField_Float16Vector{Float16Vector: []byte{1, 2, 3, 4, 5, 6, 7, 8}},
+				},
+			},
+		}
+		result := prepareNullableVectorFieldData(sample, 10)
+		assert.Equal(t, schemapb.DataType_Float16Vector, result.Type)
+		assert.Equal(t, dim, result.GetVectors().GetDim())
+		assert.Empty(t, result.GetVectors().GetFloat16Vector())
+	})
+
+	t.Run("BFloat16Vector", func(t *testing.T) {
+		sample := &schemapb.FieldData{
+			Type:      schemapb.DataType_BFloat16Vector,
+			FieldName: "bf16_vec",
+			FieldId:   102,
+			Field: &schemapb.FieldData_Vectors{
+				Vectors: &schemapb.VectorField{
+					Dim:  dim,
+					Data: &schemapb.VectorField_Bfloat16Vector{Bfloat16Vector: []byte{1, 2, 3, 4, 5, 6, 7, 8}},
+				},
+			},
+		}
+		result := prepareNullableVectorFieldData(sample, 10)
+		assert.Equal(t, schemapb.DataType_BFloat16Vector, result.Type)
+		assert.Equal(t, dim, result.GetVectors().GetDim())
+		assert.Empty(t, result.GetVectors().GetBfloat16Vector())
+	})
+
+	t.Run("BinaryVector", func(t *testing.T) {
+		sample := &schemapb.FieldData{
+			Type:      schemapb.DataType_BinaryVector,
+			FieldName: "binary_vec",
+			FieldId:   103,
+			Field: &schemapb.FieldData_Vectors{
+				Vectors: &schemapb.VectorField{
+					Dim:  int64(32), // binary vector dim is in bits
+					Data: &schemapb.VectorField_BinaryVector{BinaryVector: []byte{1, 2, 3, 4}},
+				},
+			},
+		}
+		result := prepareNullableVectorFieldData(sample, 10)
+		assert.Equal(t, schemapb.DataType_BinaryVector, result.Type)
+		assert.Equal(t, int64(32), result.GetVectors().GetDim())
+		assert.Empty(t, result.GetVectors().GetBinaryVector())
+	})
+
+	t.Run("SparseFloatVector", func(t *testing.T) {
+		sample := &schemapb.FieldData{
+			Type:      schemapb.DataType_SparseFloatVector,
+			FieldName: "sparse_vec",
+			FieldId:   104,
+			Field: &schemapb.FieldData_Vectors{
+				Vectors: &schemapb.VectorField{
+					Dim: 100,
+					Data: &schemapb.VectorField_SparseFloatVector{
+						SparseFloatVector: &schemapb.SparseFloatArray{
+							Dim:      100,
+							Contents: [][]byte{{1, 2, 3}},
+						},
+					},
+				},
+			},
+		}
+		result := prepareNullableVectorFieldData(sample, 10)
+		assert.Equal(t, schemapb.DataType_SparseFloatVector, result.Type)
+		assert.Empty(t, result.GetVectors().GetSparseFloatVector().Contents)
+	})
+
+	t.Run("nil vectors", func(t *testing.T) {
+		sample := &schemapb.FieldData{
+			Type:      schemapb.DataType_FloatVector,
+			FieldName: "empty_vec",
+			FieldId:   105,
+		}
+		result := prepareNullableVectorFieldData(sample, 10)
+		assert.Equal(t, schemapb.DataType_FloatVector, result.Type)
+		assert.Nil(t, result.GetVectors())
+	})
+}
+
+func TestAppendSingleVector(t *testing.T) {
+	t.Run("FloatVector", func(t *testing.T) {
+		dim := int64(4)
+		source := &schemapb.FieldData{
+			Field: &schemapb.FieldData_Vectors{
+				Vectors: &schemapb.VectorField{
+					Dim: dim,
+					Data: &schemapb.VectorField_FloatVector{
+						FloatVector: &schemapb.FloatArray{Data: []float32{1, 2, 3, 4, 5, 6, 7, 8}}, // 2 vectors
+					},
+				},
+			},
+		}
+		target := &schemapb.FieldData{
+			Field: &schemapb.FieldData_Vectors{
+				Vectors: &schemapb.VectorField{
+					Dim:  dim,
+					Data: &schemapb.VectorField_FloatVector{FloatVector: &schemapb.FloatArray{Data: []float32{}}},
+				},
+			},
+		}
+
+		appendSingleVector(target, source, 0)
+		assert.Equal(t, []float32{1, 2, 3, 4}, target.GetVectors().GetFloatVector().Data)
+
+		appendSingleVector(target, source, 1)
+		assert.Equal(t, []float32{1, 2, 3, 4, 5, 6, 7, 8}, target.GetVectors().GetFloatVector().Data)
+	})
+
+	t.Run("Float16Vector", func(t *testing.T) {
+		dim := int64(2)
+		source := &schemapb.FieldData{
+			Field: &schemapb.FieldData_Vectors{
+				Vectors: &schemapb.VectorField{
+					Dim:  dim,
+					Data: &schemapb.VectorField_Float16Vector{Float16Vector: []byte{1, 2, 3, 4, 5, 6, 7, 8}}, // 2 vectors (dim*2 bytes each)
+				},
+			},
+		}
+		target := &schemapb.FieldData{
+			Field: &schemapb.FieldData_Vectors{
+				Vectors: &schemapb.VectorField{
+					Dim:  dim,
+					Data: &schemapb.VectorField_Float16Vector{Float16Vector: []byte{}},
+				},
+			},
+		}
+
+		appendSingleVector(target, source, 0)
+		assert.Equal(t, []byte{1, 2, 3, 4}, target.GetVectors().GetFloat16Vector())
+
+		appendSingleVector(target, source, 1)
+		assert.Equal(t, []byte{1, 2, 3, 4, 5, 6, 7, 8}, target.GetVectors().GetFloat16Vector())
+	})
+
+	t.Run("BinaryVector", func(t *testing.T) {
+		dim := int64(16) // 16 bits = 2 bytes per vector
+		source := &schemapb.FieldData{
+			Field: &schemapb.FieldData_Vectors{
+				Vectors: &schemapb.VectorField{
+					Dim:  dim,
+					Data: &schemapb.VectorField_BinaryVector{BinaryVector: []byte{1, 2, 3, 4}}, // 2 vectors
+				},
+			},
+		}
+		target := &schemapb.FieldData{
+			Field: &schemapb.FieldData_Vectors{
+				Vectors: &schemapb.VectorField{
+					Dim:  dim,
+					Data: &schemapb.VectorField_BinaryVector{BinaryVector: []byte{}},
+				},
+			},
+		}
+
+		appendSingleVector(target, source, 0)
+		assert.Equal(t, []byte{1, 2}, target.GetVectors().GetBinaryVector())
+
+		appendSingleVector(target, source, 1)
+		assert.Equal(t, []byte{1, 2, 3, 4}, target.GetVectors().GetBinaryVector())
+	})
+
+	t.Run("SparseFloatVector", func(t *testing.T) {
+		source := &schemapb.FieldData{
+			Field: &schemapb.FieldData_Vectors{
+				Vectors: &schemapb.VectorField{
+					Dim: 100,
+					Data: &schemapb.VectorField_SparseFloatVector{
+						SparseFloatVector: &schemapb.SparseFloatArray{
+							Dim:      100,
+							Contents: [][]byte{{1, 2}, {3, 4}},
+						},
+					},
+				},
+			},
+		}
+		target := &schemapb.FieldData{
+			Field: &schemapb.FieldData_Vectors{
+				Vectors: &schemapb.VectorField{
+					Dim: 100,
+					Data: &schemapb.VectorField_SparseFloatVector{
+						SparseFloatVector: &schemapb.SparseFloatArray{
+							Dim:      0,
+							Contents: [][]byte{},
+						},
+					},
+				},
+			},
+		}
+
+		appendSingleVector(target, source, 0)
+		assert.Equal(t, [][]byte{{1, 2}}, target.GetVectors().GetSparseFloatVector().Contents)
+
+		appendSingleVector(target, source, 1)
+		assert.Equal(t, [][]byte{{1, 2}, {3, 4}}, target.GetVectors().GetSparseFloatVector().Contents)
+	})
+
+	t.Run("nil vectors", func(t *testing.T) {
+		source := &schemapb.FieldData{}
+		target := &schemapb.FieldData{}
+		// Should not panic
+		appendSingleVector(target, source, 0)
+	})
+}
+
+func TestRebuildNullableVectorFieldData(t *testing.T) {
+	dim := int64(4)
+
+	t.Run("with upsert data - all valid", func(t *testing.T) {
+		upsertField := &schemapb.FieldData{
+			Type:      schemapb.DataType_FloatVector,
+			FieldName: "vec",
+			FieldId:   100,
+			Field: &schemapb.FieldData_Vectors{
+				Vectors: &schemapb.VectorField{
+					Dim:  dim,
+					Data: &schemapb.VectorField_FloatVector{FloatVector: &schemapb.FloatArray{Data: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}},
+				},
+			},
+			ValidData: []bool{true, true, true},
+		}
+
+		ctx := &nullableVectorMergeContext{
+			upsertIdxMap:  []int64{0, 1, 2},
+			upsertField:   upsertField,
+			hasUpsertData: true,
+			mergedValid:   []bool{true, true},
+		}
+
+		updateIdxInUpsert := []int{0, 2}
+		upsertIDs := &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1, 2, 3}}}}
+		existPKToIndex := map[interface{}]int{}
+
+		result := rebuildNullableVectorFieldData(ctx, updateIdxInUpsert, upsertIDs, existPKToIndex)
+		assert.NotNil(t, result)
+		assert.Equal(t, []bool{true, true}, result.ValidData)
+		// Should have 2 vectors: index 0 and index 2 from source
+		assert.Equal(t, []float32{1, 2, 3, 4, 9, 10, 11, 12}, result.GetVectors().GetFloatVector().Data)
+	})
+
+	t.Run("with upsert data - mixed valid/null", func(t *testing.T) {
+		upsertField := &schemapb.FieldData{
+			Type:      schemapb.DataType_FloatVector,
+			FieldName: "vec",
+			FieldId:   100,
+			Field: &schemapb.FieldData_Vectors{
+				Vectors: &schemapb.VectorField{
+					Dim:  dim,
+					Data: &schemapb.VectorField_FloatVector{FloatVector: &schemapb.FloatArray{Data: []float32{1, 2, 3, 4, 5, 6, 7, 8}}}, // only 2 vectors (compressed)
+				},
+			},
+			ValidData: []bool{true, false, true}, // row 0 valid, row 1 null, row 2 valid
+		}
+
+		ctx := &nullableVectorMergeContext{
+			upsertIdxMap:  []int64{0, -1, 1}, // row 0 -> data 0, row 1 -> null, row 2 -> data 1
+			upsertField:   upsertField,
+			hasUpsertData: true,
+			mergedValid:   []bool{true, false}, // result: row 0 valid, row 1 null
+		}
+
+		updateIdxInUpsert := []int{0, 1}
+		upsertIDs := &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1, 2, 3}}}}
+		existPKToIndex := map[interface{}]int{}
+
+		result := rebuildNullableVectorFieldData(ctx, updateIdxInUpsert, upsertIDs, existPKToIndex)
+		assert.NotNil(t, result)
+		assert.Equal(t, []bool{true, false}, result.ValidData)
+		// Should have only 1 vector (row 0), row 1 is null
+		assert.Equal(t, []float32{1, 2, 3, 4}, result.GetVectors().GetFloatVector().Data)
+	})
+
+	t.Run("without upsert data - use exist data", func(t *testing.T) {
+		existField := &schemapb.FieldData{
+			Type:      schemapb.DataType_FloatVector,
+			FieldName: "vec",
+			FieldId:   100,
+			Field: &schemapb.FieldData_Vectors{
+				Vectors: &schemapb.VectorField{
+					Dim:  dim,
+					Data: &schemapb.VectorField_FloatVector{FloatVector: &schemapb.FloatArray{Data: []float32{1, 2, 3, 4, 5, 6, 7, 8}}},
+				},
+			},
+			ValidData: []bool{true, true},
+		}
+
+		ctx := &nullableVectorMergeContext{
+			existIdxMap:   []int64{0, 1},
+			existField:    existField,
+			hasUpsertData: false,
+			mergedValid:   []bool{true, true},
+		}
+
+		updateIdxInUpsert := []int{0, 1}
+		upsertIDs := &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{10, 20}}}}
+		existPKToIndex := map[interface{}]int{int64(10): 0, int64(20): 1}
+
+		result := rebuildNullableVectorFieldData(ctx, updateIdxInUpsert, upsertIDs, existPKToIndex)
+		assert.NotNil(t, result)
+		assert.Equal(t, []bool{true, true}, result.ValidData)
+		assert.Equal(t, []float32{1, 2, 3, 4, 5, 6, 7, 8}, result.GetVectors().GetFloatVector().Data)
+	})
+
+	t.Run("nil source field", func(t *testing.T) {
+		ctx := &nullableVectorMergeContext{
+			hasUpsertData: false,
+			existField:    nil,
+			mergedValid:   []bool{true},
+		}
+
+		result := rebuildNullableVectorFieldData(ctx, []int{0}, nil, nil)
+		assert.Nil(t, result)
+	})
+}
@@ -226,6 +226,14 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
 				return err
 			}

+			expectedRows := getExpectedVectorRows(field, f)
+			if field.GetVectors() == nil {
+				if expectedRows != 0 {
+					return errNumRowsMismatch(field.GetFieldName(), 0)
+				}
+				continue
+			}
+
 			dim, err := typeutil.GetDim(f)
 			if err != nil {
 				return err
@@ -240,7 +248,6 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
 				return errDimMismatch(field.GetFieldName(), dataDim, dim)
 			}

-			expectedRows := getExpectedVectorRows(field, f)
 			if n != expectedRows {
 				return errNumRowsMismatch(field.GetFieldName(), n)
 			}
@@ -251,6 +258,14 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
 				return err
 			}

+			expectedRows := getExpectedVectorRows(field, f)
+			if field.GetVectors() == nil {
+				if expectedRows != 0 {
+					return errNumRowsMismatch(field.GetFieldName(), 0)
+				}
+				continue
+			}
+
 			dim, err := typeutil.GetDim(f)
 			if err != nil {
 				return err
@@ -265,7 +280,6 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
 				return err
 			}

-			expectedRows := getExpectedVectorRows(field, f)
 			if n != expectedRows {
 				return errNumRowsMismatch(field.GetFieldName(), n)
 			}
@@ -276,6 +290,14 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
 				return err
 			}

+			expectedRows := getExpectedVectorRows(field, f)
+			if field.GetVectors() == nil {
+				if expectedRows != 0 {
+					return errNumRowsMismatch(field.GetFieldName(), 0)
+				}
+				continue
+			}
+
 			dim, err := typeutil.GetDim(f)
 			if err != nil {
 				return err
@@ -290,7 +312,6 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
 				return err
 			}

-			expectedRows := getExpectedVectorRows(field, f)
 			if n != expectedRows {
 				return errNumRowsMismatch(field.GetFieldName(), n)
 			}
@@ -301,6 +322,14 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
 				return err
 			}

+			expectedRows := getExpectedVectorRows(field, f)
+			if field.GetVectors() == nil {
+				if expectedRows != 0 {
+					return errNumRowsMismatch(field.GetFieldName(), 0)
+				}
+				continue
+			}
+
 			dim, err := typeutil.GetDim(f)
 			if err != nil {
 				return err
@@ -315,7 +344,6 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
 				return err
 			}

-			expectedRows := getExpectedVectorRows(field, f)
 			if n != expectedRows {
 				return errNumRowsMismatch(field.GetFieldName(), n)
 			}
@@ -325,8 +353,16 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
 			if err != nil {
 				return err
 			}
-			n := uint64(len(field.GetVectors().GetSparseFloatVector().Contents))

 			expectedRows := getExpectedVectorRows(field, f)
+			if field.GetVectors() == nil || field.GetVectors().GetSparseFloatVector() == nil {
+				if expectedRows != 0 {
+					return errNumRowsMismatch(field.GetFieldName(), 0)
+				}
+				continue
+			}
+
+			n := uint64(len(field.GetVectors().GetSparseFloatVector().Contents))
 			if n != expectedRows {
 				return errNumRowsMismatch(field.GetFieldName(), n)
 			}
@@ -337,6 +373,14 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
 				return err
 			}

+			expectedRows := getExpectedVectorRows(field, f)
+			if field.GetVectors() == nil {
+				if expectedRows != 0 {
+					return errNumRowsMismatch(field.GetFieldName(), 0)
+				}
+				continue
+			}
+
 			dim, err := typeutil.GetDim(f)
 			if err != nil {
 				return err
@@ -351,7 +395,6 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil
 				return errDimMismatch(field.GetFieldName(), dataDim, dim)
 			}

-			expectedRows := getExpectedVectorRows(field, f)
 			if n != expectedRows {
 				return errNumRowsMismatch(field.GetFieldName(), n)
 			}
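checkAligned now tolerates Vectors == nil when every row of a nullable vector field is null. The diff leans on the pre-existing getExpectedVectorRows helper, whose body is not shown here; presumably it counts valid rows for nullable fields and falls back to the total row count otherwise. A hedged sketch of that assumed behavior:

// Assumed behavior of getExpectedVectorRows (not shown in this diff):
// in the compressed nullable layout only valid rows carry vector data,
// so the expected data-row count is the number of true ValidData entries.
func getExpectedVectorRowsSketch(validData []bool, numRows uint64) uint64 {
	if len(validData) == 0 {
		return numRows // non-nullable (or validity omitted): every row has data
	}
	var n uint64
	for _, ok := range validData {
		if ok {
			n++
		}
	}
	return n
}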
@@ -1998,6 +1998,292 @@ func Test_validateUtil_checkAligned(t *testing.T) {

 		assert.NoError(t, err)
 	})
+
+	t.Run("nullable float vector all null", func(t *testing.T) {
+		data := []*schemapb.FieldData{
+			{
+				FieldName: "test",
+				Type:      schemapb.DataType_FloatVector,
+				ValidData: []bool{false, false, false},
+			},
+		}
+
+		schema := &schemapb.CollectionSchema{
+			Fields: []*schemapb.FieldSchema{
+				{
+					Name:     "test",
+					DataType: schemapb.DataType_FloatVector,
+					TypeParams: []*commonpb.KeyValuePair{
+						{
+							Key:   common.DimKey,
+							Value: "8",
+						},
+					},
+					Nullable: true,
+				},
+			},
+		}
+		h, err := typeutil.CreateSchemaHelper(schema)
+		assert.NoError(t, err)
+
+		v := newValidateUtil()
+
+		err = v.checkAligned(data, h, 3)
+
+		assert.NoError(t, err)
+	})
+
+	t.Run("nullable float vector partial null", func(t *testing.T) {
+		data := []*schemapb.FieldData{
+			{
+				FieldName: "test",
+				Type:      schemapb.DataType_FloatVector,
+				Field: &schemapb.FieldData_Vectors{
+					Vectors: &schemapb.VectorField{
+						Data: &schemapb.VectorField_FloatVector{
+							FloatVector: &schemapb.FloatArray{
+								Data: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 2 vectors of dim 8
+							},
+						},
+						Dim: 8,
+					},
+				},
+				ValidData: []bool{true, false, true, false},
+			},
+		}
+
+		schema := &schemapb.CollectionSchema{
+			Fields: []*schemapb.FieldSchema{
+				{
+					Name:     "test",
+					DataType: schemapb.DataType_FloatVector,
+					TypeParams: []*commonpb.KeyValuePair{
+						{
+							Key:   common.DimKey,
+							Value: "8",
+						},
+					},
+					Nullable: true,
+				},
+			},
+		}
+		h, err := typeutil.CreateSchemaHelper(schema)
+		assert.NoError(t, err)
+
+		v := newValidateUtil()
+
+		err = v.checkAligned(data, h, 4)
+
+		assert.NoError(t, err)
+	})
+
+	t.Run("nullable float vector mismatch valid count", func(t *testing.T) {
+		data := []*schemapb.FieldData{
+			{
+				FieldName: "test",
+				Type:      schemapb.DataType_FloatVector,
+				Field: &schemapb.FieldData_Vectors{
+					Vectors: &schemapb.VectorField{
+						Data: &schemapb.VectorField_FloatVector{
+							FloatVector: &schemapb.FloatArray{
+								Data: []float32{1, 2, 3, 4, 5, 6, 7, 8}, // only 1 vector
+							},
+						},
+						Dim: 8,
+					},
+				},
+				ValidData: []bool{true, false, true, false}, // 2 valid
+			},
+		}
+
+		schema := &schemapb.CollectionSchema{
+			Fields: []*schemapb.FieldSchema{
+				{
+					Name:     "test",
+					DataType: schemapb.DataType_FloatVector,
+					TypeParams: []*commonpb.KeyValuePair{
+						{
+							Key:   common.DimKey,
+							Value: "8",
+						},
+					},
+					Nullable: true,
+				},
+			},
+		}
+		h, err := typeutil.CreateSchemaHelper(schema)
+		assert.NoError(t, err)
+
+		v := newValidateUtil()
+
+		err = v.checkAligned(data, h, 4)
+
+		assert.Error(t, err)
+	})
+
+	t.Run("nullable binary vector all null", func(t *testing.T) {
+		data := []*schemapb.FieldData{
+			{
+				FieldName: "test",
+				Type:      schemapb.DataType_BinaryVector,
+				ValidData: []bool{false, false, false},
+			},
+		}
+
+		schema := &schemapb.CollectionSchema{
+			Fields: []*schemapb.FieldSchema{
+				{
+					Name:     "test",
+					DataType: schemapb.DataType_BinaryVector,
+					TypeParams: []*commonpb.KeyValuePair{
+						{
+							Key:   common.DimKey,
+							Value: "8",
+						},
+					},
+					Nullable: true,
+				},
+			},
+		}
+		h, err := typeutil.CreateSchemaHelper(schema)
+		assert.NoError(t, err)
+
+		v := newValidateUtil()
+
+		err = v.checkAligned(data, h, 3)
+
+		assert.NoError(t, err)
+	})
+
+	t.Run("nullable float16 vector all null", func(t *testing.T) {
+		data := []*schemapb.FieldData{
+			{
+				FieldName: "test",
+				Type:      schemapb.DataType_Float16Vector,
+				ValidData: []bool{false, false, false},
+			},
+		}
+
+		schema := &schemapb.CollectionSchema{
+			Fields: []*schemapb.FieldSchema{
+				{
+					Name:     "test",
+					DataType: schemapb.DataType_Float16Vector,
+					TypeParams: []*commonpb.KeyValuePair{
+						{
+							Key:   common.DimKey,
+							Value: "8",
+						},
+					},
+					Nullable: true,
+				},
+			},
+		}
+		h, err := typeutil.CreateSchemaHelper(schema)
+		assert.NoError(t, err)
+
+		v := newValidateUtil()
+
+		err = v.checkAligned(data, h, 3)
+
+		assert.NoError(t, err)
+	})
+
+	t.Run("nullable bfloat16 vector all null", func(t *testing.T) {
+		data := []*schemapb.FieldData{
+			{
+				FieldName: "test",
+				Type:      schemapb.DataType_BFloat16Vector,
+				ValidData: []bool{false, false, false},
+			},
+		}
+
+		schema := &schemapb.CollectionSchema{
+			Fields: []*schemapb.FieldSchema{
+				{
+					Name:     "test",
+					DataType: schemapb.DataType_BFloat16Vector,
+					TypeParams: []*commonpb.KeyValuePair{
+						{
+							Key:   common.DimKey,
+							Value: "8",
+						},
+					},
+					Nullable: true,
+				},
+			},
+		}
+		h, err := typeutil.CreateSchemaHelper(schema)
+		assert.NoError(t, err)
+
+		v := newValidateUtil()
+
+		err = v.checkAligned(data, h, 3)
+
+		assert.NoError(t, err)
+	})
+
+	t.Run("nullable int8 vector all null", func(t *testing.T) {
+		data := []*schemapb.FieldData{
+			{
+				FieldName: "test",
+				Type:      schemapb.DataType_Int8Vector,
+				ValidData: []bool{false, false, false},
+			},
+		}
+
+		schema := &schemapb.CollectionSchema{
+			Fields: []*schemapb.FieldSchema{
+				{
+					Name:     "test",
+					DataType: schemapb.DataType_Int8Vector,
+					TypeParams: []*commonpb.KeyValuePair{
+						{
+							Key:   common.DimKey,
+							Value: "8",
+						},
+					},
+					Nullable: true,
+				},
+			},
+		}
+		h, err := typeutil.CreateSchemaHelper(schema)
+		assert.NoError(t, err)
+
+		v := newValidateUtil()
+
+		err = v.checkAligned(data, h, 3)
+
+		assert.NoError(t, err)
+	})
+
+	t.Run("nullable sparse vector all null", func(t *testing.T) {
+		data := []*schemapb.FieldData{
+			{
+				FieldName: "test",
+				Type:      schemapb.DataType_SparseFloatVector,
+				ValidData: []bool{false, false, false},
+			},
+		}
+
+		schema := &schemapb.CollectionSchema{
+			Fields: []*schemapb.FieldSchema{
+				{
+					Name:     "test",
+					DataType: schemapb.DataType_SparseFloatVector,
+					Nullable: true,
+				},
+			},
+		}
+		h, err := typeutil.CreateSchemaHelper(schema)
+		assert.NoError(t, err)
+
+		v := newValidateUtil()
+
+		err = v.checkAligned(data, h, 3)
+
+		assert.NoError(t, err)
+	})
 }

 func Test_validateUtil_Validate(t *testing.T) {
@@ -1732,43 +1732,44 @@ func TestNullableVectorUpsert(t *testing.T) {
 	ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout)
 	mc := hp.CreateDefaultMilvusClient(ctx, t)

-	collName := common.GenRandomString("nullable_vec_ups", 5)
+	// Create collection with pk, scalar, and nullable vector fields
+	collName := common.GenRandomString("nullable_vec_upsert", 5)
 	pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true)
 	if autoID {
 		pkField.AutoID = true
 	}
+	scalarField := entity.NewField().WithName("scalar").WithDataType(entity.FieldTypeInt32).WithNullable(true)
 	vecField := entity.NewField().WithName(common.DefaultFloatVecFieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim).WithNullable(true)
-	schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField)
+	schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(scalarField).WithField(vecField)

 	err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong))
 	common.CheckErr(t, err, true)

-	// insert initial data with 50% null
+	// Insert initial data: 100 rows
+	// Rows 0 to nullPercent-1: valid vector, scalar = i*10
+	// Rows nullPercent to nb-1: null vector, scalar = i*10
 	nb := 100
 	nullPercent := 50
 	validData := make([]bool, nb)
 	validCount := 0
 	for i := range nb {
-		validData[i] = (i % 100) >= nullPercent
+		validData[i] = i < nullPercent
 		if validData[i] {
 			validCount++
 		}
 	}

 	pkData := make([]int64, nb)
+	scalarData := make([]int32, nb)
+	scalarValidData := make([]bool, nb)
 	for i := range nb {
 		pkData[i] = int64(i)
+		scalarData[i] = int32(i * 10)
+		scalarValidData[i] = true
 	}
 	pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData)

-	pkToVecIdx := make(map[int64]int)
-	vecIdx := 0
-	for i := range nb {
-		if validData[i] {
-			pkToVecIdx[int64(i)] = vecIdx
-			vecIdx++
-		}
-	}
+	scalarColumn, err := column.NewNullableColumnInt32("scalar", scalarData, scalarValidData)
+	common.CheckErr(t, err, true)

 	vectors := make([][]float32, validCount)
 	for i := range validCount {
@@ -1783,9 +1784,9 @@ func TestNullableVectorUpsert(t *testing.T) {

 	var insertRes client.InsertResult
 	if autoID {
-		insertRes, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(collName).WithColumns(vecColumn))
+		insertRes, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(collName).WithColumns(scalarColumn, vecColumn))
 	} else {
-		insertRes, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(collName).WithColumns(pkColumn, vecColumn))
+		insertRes, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, scalarColumn, vecColumn))
 	}
 	common.CheckErr(t, err, true)
 	require.EqualValues(t, nb, insertRes.InsertCount)
@@ -1814,125 +1815,214 @@ func TestNullableVectorUpsert(t *testing.T) {
 	err = loadTask.Await(ctx)
 	common.CheckErr(t, err, true)

-	// upsert: change first 25 rows (originally null) to valid, change rows 50-74 (originally valid) to null
-	upsertNb := 50
-	upsertValidData := make([]bool, upsertNb)
-	for i := range upsertNb {
-		upsertValidData[i] = i < 25
-	}
-
-	upsertPkData := make([]int64, upsertNb)
-	for i := range upsertNb {
-		if i < 25 {
-			upsertPkData[i] = actualPkData[i]
-		} else {
-			upsertPkData[i] = actualPkData[i+25]
-		}
-	}
-	upsertPkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, upsertPkData)
-
-	upsertVectors := make([][]float32, 25)
-	for i := range 25 {
-		vec := make([]float32, common.DefaultDim)
-		for j := range common.DefaultDim {
-			vec[j] = float32((i+100)*common.DefaultDim+j) / 10000.0
-		}
-		upsertVectors[i] = vec
-	}
-	upsertVecColumn, err := column.NewNullableColumnFloatVector(common.DefaultFloatVecFieldName, common.DefaultDim, upsertVectors, upsertValidData)
-	common.CheckErr(t, err, true)
-
-	upsertRes, err := mc.Upsert(ctx, client.NewColumnBasedInsertOption(collName, upsertPkColumn, upsertVecColumn))
-	common.CheckErr(t, err, true)
-	require.EqualValues(t, upsertNb, upsertRes.UpsertCount)
-
-	var upsertedPks []int64
-	if autoID {
-		upsertedIDs := upsertRes.IDs.(*column.ColumnInt64)
-		upsertedPks = upsertedIDs.Data()
-		require.EqualValues(t, upsertNb, len(upsertedPks), "upserted PK count should match")
-	} else {
-		upsertedPks = upsertPkData
-	}
-
 	// Track expected state
 	expectedVectorMap := make(map[int64][]float32)
-	for i := 0; i < 25; i++ {
-		expectedVectorMap[upsertedPks[i]] = upsertVectors[i]
-	}
-	for i := 25; i < 50; i++ {
-		expectedVectorMap[upsertedPks[i]] = nil
-	}
-	for i := 25; i < 50; i++ {
-		expectedVectorMap[actualPkData[i]] = nil
-	}
-	for i := 75; i < 100; i++ {
-		vecIdx := i - 50
-		expectedVectorMap[actualPkData[i]] = vectors[vecIdx]
+	expectedScalarMap := make(map[int64]int32)
+	for i := range nb {
+		expectedScalarMap[actualPkData[i]] = scalarData[i]
+		if i < nullPercent {
+			expectedVectorMap[actualPkData[i]] = vectors[i]
+		} else {
+			expectedVectorMap[actualPkData[i]] = nil
+		}
 	}

-	time.Sleep(10 * time.Second)
-	flushTask, err = mc.Flush(ctx, client.NewFlushOption(collName))
-	common.CheckErr(t, err, true)
-	err = flushTask.Await(ctx)
-	common.CheckErr(t, err, true)
+	// Helper: flush, reload, then verify via search and query
+	flushAndVerify := func(expectedValidCount int, context string) {
+		time.Sleep(10 * time.Second)
+		flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName))
+		common.CheckErr(t, err, true)
+		err = flushTask.Await(ctx)
+		common.CheckErr(t, err, true)

-	err = mc.ReleaseCollection(ctx, client.NewReleaseCollectionOption(collName))
-	common.CheckErr(t, err, true)
+		err = mc.ReleaseCollection(ctx, client.NewReleaseCollectionOption(collName))
+		common.CheckErr(t, err, true)

-	loadTask, err = mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName))
-	common.CheckErr(t, err, true)
-	err = loadTask.Await(ctx)
-	common.CheckErr(t, err, true)
+		loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName))
+		common.CheckErr(t, err, true)
+		err = loadTask.Await(ctx)
+		common.CheckErr(t, err, true)

-	expectedValidCount := 50
-	searchVec := entity.FloatVector(common.GenFloatVector(common.DefaultDim))
-	searchRes, err := mc.Search(ctx, client.NewSearchOption(collName, 50, []entity.Vector{searchVec}).WithANNSField(common.DefaultFloatVecFieldName))
-	common.CheckErr(t, err, true)
-	require.EqualValues(t, 1, len(searchRes))
-	searchIDs := searchRes[0].IDs.(*column.ColumnInt64).Data()
-	require.EqualValues(t, expectedValidCount, len(searchIDs), "search should return all 50 valid vectors")
+		// Search verify
+		searchVec := entity.FloatVector(common.GenFloatVector(common.DefaultDim))
+		searchRes, err := mc.Search(ctx, client.NewSearchOption(collName, 100, []entity.Vector{searchVec}).WithANNSField(common.DefaultFloatVecFieldName))
+		common.CheckErr(t, err, true)
+		require.EqualValues(t, 1, len(searchRes))
+		require.EqualValues(t, expectedValidCount, searchRes[0].ResultCount, "%s: search should return %d valid vectors", context, expectedValidCount)

-	verifyVectorData := func(queryResult client.ResultSet, context string) {
-		pkCol := queryResult.GetColumn(common.DefaultInt64FieldName).(*column.ColumnInt64)
-		vecCol := queryResult.GetColumn(common.DefaultFloatVecFieldName).(*column.ColumnFloatVector)
-		for i := 0; i < queryResult.ResultCount; i++ {
+		// Query verify all rows
+		queryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("int64 in [%s]", int64SliceToString(actualPkData))).WithOutputFields("scalar", common.DefaultFloatVecFieldName))
+		common.CheckErr(t, err, true)
+		require.EqualValues(t, nb, queryRes.ResultCount, "%s: query should return %d rows", context, nb)
+
+		pkCol := queryRes.GetColumn(common.DefaultInt64FieldName).(*column.ColumnInt64)
+		scalarCol := queryRes.GetColumn("scalar").(*column.ColumnInt32)
+		vecCol := queryRes.GetColumn(common.DefaultFloatVecFieldName).(*column.ColumnFloatVector)
+
+		for i := 0; i < queryRes.ResultCount; i++ {
 			pk, _ := pkCol.GetAsInt64(i)
+			scalarVal, _ := scalarCol.GetAsInt64(i)
 			isNull, _ := vecCol.IsNull(i)

-			expectedVec, exists := expectedVectorMap[pk]
-			require.True(t, exists, "%s: unexpected PK %d in query results", context, pk)
+			expectedScalar := expectedScalarMap[pk]
+			require.EqualValues(t, expectedScalar, scalarVal, "%s: scalar mismatch for pk %d", context, pk)
+
+			expectedVec := expectedVectorMap[pk]
 			if expectedVec != nil {
 				require.False(t, isNull, "%s: vector should not be null for pk %d", context, pk)
 				vecData, _ := vecCol.Get(i)
 				queriedVec := []float32(vecData.(entity.FloatVector))
 				require.EqualValues(t, common.DefaultDim, len(queriedVec), "%s: vector dimension should match for pk %d", context, pk)
 				for j := range expectedVec {
-					require.InDelta(t, expectedVec[j], queriedVec[j], 1e-6, "%s: vector element %d should match for pk %d", context, j, pk)
+					require.InDelta(t, expectedVec[j], queriedVec[j], 1e-6, "%s: vector element %d mismatch for pk %d", context, j, pk)
 				}
 			} else {
 				require.True(t, isNull, "%s: vector should be null for pk %d", context, pk)
 				vecData, _ := vecCol.Get(i)
 				require.Nil(t, vecData, "%s: null vector data should be nil for pk %d", context, pk)
 			}
 		}
 	}

-	searchQueryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("int64 in [%s]", int64SliceToString(searchIDs))).WithOutputFields(common.DefaultFloatVecFieldName))
-	common.CheckErr(t, err, true)
-	verifyVectorData(searchQueryRes, "All valid vectors after upsert")
+	// Upsert 1: Update all 100 rows to null vectors
+	upsert1PkData := make([]int64, nb)
+	for i := range nb {
+		upsert1PkData[i] = actualPkData[i]
+	}
+	upsert1PkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, upsert1PkData)

-	upsertedToValidPKs := upsertedPks[0:25]
-	queryUpsertedRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("int64 in [%s]", int64SliceToString(upsertedToValidPKs))).WithOutputFields(common.DefaultFloatVecFieldName))
+	allNullValidData := make([]bool, nb)
+	upsert1VecColumn, err := column.NewNullableColumnFloatVector(common.DefaultFloatVecFieldName, common.DefaultDim, [][]float32{}, allNullValidData)
 	common.CheckErr(t, err, true)
-	require.EqualValues(t, 25, queryUpsertedRes.ResultCount, "should have 25 rows for upserted to valid")
-	verifyVectorData(queryUpsertedRes, "Upserted valid rows")

-	countRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("").WithOutputFields("count(*)"))
+	upsert1ScalarData := make([]int32, nb)
+	upsert1ScalarValidData := make([]bool, nb)
+	for i := range nb {
+		upsert1ScalarData[i] = int32(i * 100)
+		upsert1ScalarValidData[i] = true
+	}
+	upsert1ScalarColumn, err := column.NewNullableColumnInt32("scalar", upsert1ScalarData, upsert1ScalarValidData)
 	common.CheckErr(t, err, true)
-	totalCount, err := countRes.Fields[0].GetAsInt64(0)

+	upsertRes1, err := mc.Upsert(ctx, client.NewColumnBasedInsertOption(collName, upsert1PkColumn, upsert1ScalarColumn, upsert1VecColumn))
 	common.CheckErr(t, err, true)
-	require.EqualValues(t, nb, totalCount, "total count after upsert should still be %d", nb)
+	require.EqualValues(t, nb, upsertRes1.UpsertCount)
+
+	// For AutoID=true, upsert returns new IDs, need to update actualPkData
+	if autoID {
+		upsertedIDs := upsertRes1.IDs.(*column.ColumnInt64)
+		newPkData := upsertedIDs.Data()
+		// Clear old expected state
+		expectedVectorMap = make(map[int64][]float32)
+		expectedScalarMap = make(map[int64]int32)
+		// Update actualPkData with new IDs
+		actualPkData = newPkData
+	}
+
+	// Update expected state: all vectors null
+	for i := range nb {
+		expectedVectorMap[actualPkData[i]] = nil
+		expectedScalarMap[actualPkData[i]] = upsert1ScalarData[i]
+	}
+
+	// Verify after upsert 1: search should return 0 (all null)
+	flushAndVerify(0, "After Upsert1-AllNull")
+
+	// Upsert 2: Update rows nullPercent to nb-1 to valid vectors
+	upsert2Nb := nb - nullPercent
+	upsert2PkData := make([]int64, upsert2Nb)
+	for i := range upsert2Nb {
+		upsert2PkData[i] = actualPkData[i+nullPercent]
+	}
+	upsert2PkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, upsert2PkData)
+
+	upsert2Vectors := make([][]float32, upsert2Nb)
+	upsert2ValidData := make([]bool, upsert2Nb)
+	for i := range upsert2Nb {
+		vec := make([]float32, common.DefaultDim)
+		for j := range common.DefaultDim {
+			vec[j] = float32((i+500)*common.DefaultDim+j) / 10000.0
+		}
+		upsert2Vectors[i] = vec
+		upsert2ValidData[i] = true
+	}
+	upsert2VecColumn, err := column.NewNullableColumnFloatVector(common.DefaultFloatVecFieldName, common.DefaultDim, upsert2Vectors, upsert2ValidData)
+	common.CheckErr(t, err, true)
+
+	upsert2ScalarData := make([]int32, upsert2Nb)
+	upsert2ScalarValidData := make([]bool, upsert2Nb)
+	for i := range upsert2Nb {
+		upsert2ScalarData[i] = int32((i + nullPercent) * 200)
+		upsert2ScalarValidData[i] = true
+	}
+	upsert2ScalarColumn, err := column.NewNullableColumnInt32("scalar", upsert2ScalarData, upsert2ScalarValidData)
+	common.CheckErr(t, err, true)
+
+	upsertRes2, err := mc.Upsert(ctx, client.NewColumnBasedInsertOption(collName, upsert2PkColumn, upsert2ScalarColumn, upsert2VecColumn))
+	common.CheckErr(t, err, true)
+	require.EqualValues(t, upsert2Nb, upsertRes2.UpsertCount)
+
+	// For AutoID=true, upsert returns new IDs for the upserted rows
+	if autoID {
+		upsertedIDs := upsertRes2.IDs.(*column.ColumnInt64)
+		newPkData := upsertedIDs.Data()
+		// Update actualPkData for rows nullPercent to nb-1
+		for i := range upsert2Nb {
+			// Remove old expected state
+			delete(expectedVectorMap, actualPkData[i+nullPercent])
+			delete(expectedScalarMap, actualPkData[i+nullPercent])
+			// Update to new PK
+			actualPkData[i+nullPercent] = newPkData[i]
+		}
+	}
+
+	// Update expected state: rows nullPercent to nb-1 now have valid vectors
+	for i := range upsert2Nb {
+		expectedVectorMap[actualPkData[i+nullPercent]] = upsert2Vectors[i]
+		expectedScalarMap[actualPkData[i+nullPercent]] = upsert2ScalarData[i]
+	}
+
+	// Verify after upsert 2: search should return upsert2Nb (rows nullPercent to nb-1 valid)
+	flushAndVerify(upsert2Nb, "After Upsert2-NullToValid")
+
+	// Upsert 3: Partial update rows 0 to nullPercent-1 (only scalar), vector preserved (still null)
+	upsert3Nb := nullPercent
+	upsert3PkData := make([]int64, upsert3Nb)
+	upsert3ScalarData := make([]int32, upsert3Nb)
+	upsert3ScalarValidData := make([]bool, upsert3Nb)
+	for i := range upsert3Nb {
+		upsert3PkData[i] = actualPkData[i]
+		upsert3ScalarData[i] = int32(i * 1000)
+		upsert3ScalarValidData[i] = true
+	}
+	upsert3PkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, upsert3PkData)
+	upsert3ScalarColumn, err := column.NewNullableColumnInt32("scalar", upsert3ScalarData, upsert3ScalarValidData)
+	common.CheckErr(t, err, true)
+
+	upsertRes3, err := mc.Upsert(ctx, client.NewColumnBasedInsertOption(collName, upsert3PkColumn, upsert3ScalarColumn).WithPartialUpdate(true))
+	common.CheckErr(t, err, true)
+	require.EqualValues(t, upsert3Nb, upsertRes3.UpsertCount)
+
+	// For AutoID=true, upsert returns new IDs for the upserted rows
+	if autoID {
+		upsertedIDs := upsertRes3.IDs.(*column.ColumnInt64)
+		newPkData := upsertedIDs.Data()
+		// Update actualPkData for rows 0 to nullPercent-1
+		for i := range upsert3Nb {
+			// Remove old expected state
+			delete(expectedVectorMap, actualPkData[i])
+			delete(expectedScalarMap, actualPkData[i])
+			// Update to new PK
+			actualPkData[i] = newPkData[i]
+		}
+	}
+
+	// Update expected state: rows 0 to nullPercent-1 scalar updated, vector preserved (null)
+	for i := range upsert3Nb {
+		expectedScalarMap[actualPkData[i]] = upsert3ScalarData[i]
+		// Vector remains null (preserved from before)
+		expectedVectorMap[actualPkData[i]] = nil
+	}
+
+	// Verify after upsert 3: search should still return upsert2Nb (only rows nullPercent to nb-1 valid)
+	flushAndVerify(upsert2Nb, "After Upsert3-PartialUpdate")

 	// clean up
 	err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName))
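The query filters above rely on an int64SliceToString helper that is not part of this diff; assuming it simply joins IDs with commas for the `int64 in [...]` expression, a plausible sketch:

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// Assumed shape of the int64SliceToString helper used in the filters above
// (not shown in this diff): join IDs with commas so they can be spliced
// into an `int64 in [...]` filter expression.
func int64SliceToString(ids []int64) string {
	parts := make([]string, 0, len(ids))
	for _, id := range ids {
		parts = append(parts, strconv.FormatInt(id, 10))
	}
	return strings.Join(parts, ",")
}

func main() {
	fmt.Printf("int64 in [%s]\n", int64SliceToString([]int64{10, 20, 30}))
	// Output: int64 in [10,20,30]
}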