From 4dad649549f18ce8fc684401adf99f01eee75c05 Mon Sep 17 00:00:00 2001
From: marcelo-cjl
Date: Thu, 15 Jan 2026 19:07:28 +0800
Subject: [PATCH] fix: partial upsert with nullable vector fields (#46936)

issue: #46849
related: #45993

Nullable vector fields use compressed storage: the validity bitmap marks
which rows are null, and the vector payload stores data only for the
valid rows. For example, with dim=2:

    ValidData:   [true, false, true]
    FloatVector: [a0, a1, c0, c1]    (row 1 stores no data)

Row indices therefore no longer line up with vector-data indices, which
broke partial upsert when existing data is merged with upsert data.

Changes:
- Add nullableVectorMergeContext to track row-to-data index mappings
- Add buildNullableVectorIdxMap to build the index mapping from ValidData
- Add rebuildNullableVectorFieldData to merge vectors in compressed format
- Add prepareNullableVectorFieldData/appendSingleVector helpers covering
  all vector types
- Fix checkAligned to handle all-null nullable vector fields (Vectors=nil)
- Add unit tests for the new functions
- Add e2e tests covering 3 upsert scenarios:
  - Upsert all rows to null vectors
  - Upsert null rows to valid vectors
  - Partial upsert (scalar only, vector preserved)

Signed-off-by: marcelo-cjl
---
 internal/proxy/task_upsert.go                 | 246 +++++++++++-
 internal/proxy/task_upsert_test.go            | 380 ++++++++++++++++++
 internal/proxy/validate_util.go               |  55 ++-
 internal/proxy/validate_util_test.go          | 286 +++++++++++++
 .../testcases/nullable_default_value_test.go  | 302 +++++++-----
 5 files changed, 1150 insertions(+), 119 deletions(-)

diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go
index 8d9c42e8f4..9e73ab321e 100644
--- a/internal/proxy/task_upsert.go
+++ b/internal/proxy/task_upsert.go
@@ -365,9 +365,10 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
 	it.insertFieldData = typeutil.PrepareResultFieldData(existFieldData, int64(upsertIDSize))
 
 	if len(updateIdxInUpsert) > 0 {
-		// Note: For fields containing default values, default values need to be set according to valid data during insertion,
-		// but query results fields do not set valid data when returning default value fields,
-		// therefore valid data needs to be manually set to true
+		upsertFieldMap := lo.SliceToMap(it.upsertMsg.InsertMsg.GetFieldsData(), func(field *schemapb.FieldData) (int64, *schemapb.FieldData) {
+			return field.FieldId, field
+		})
+		nullableVectorContexts := make(map[int64]*nullableVectorMergeContext)
 		for _, fieldData := range existFieldData {
 			fieldSchema, err := it.schema.schemaHelper.GetFieldFromName(fieldData.GetFieldName())
 			if err != nil {
@@ -375,10 +376,29 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
 				return err
 			}
 
-			if fieldSchema.GetDefaultValue() != nil {
-				fieldData.ValidData = make([]bool, upsertIDSize)
-				for i := range fieldData.ValidData {
-					fieldData.ValidData[i] = true
+			if fieldSchema.GetNullable() && typeutil.IsVectorType(fieldSchema.GetDataType()) {
+				ctx := &nullableVectorMergeContext{
+					existIdxMap: buildNullableVectorIdxMap(fieldData.GetValidData()),
+					existField:  fieldData,
+					mergedValid: make([]bool, 0, len(updateIdxInUpsert)),
+				}
+
+				if upsertField, ok := upsertFieldMap[fieldData.FieldId]; ok {
+					ctx.upsertIdxMap = buildNullableVectorIdxMap(upsertField.GetValidData())
+					ctx.upsertField = upsertField
+					ctx.hasUpsertData = true
+				}
+
+				nullableVectorContexts[fieldData.FieldId] = ctx
+			} else if fieldSchema.GetDefaultValue() != nil {
+				// Note: for fields with default values, insertion applies the default according to
+				// valid data, but query results do not set valid data when returning default-value
+				// fields, so the validity bitmap must be reconstructed manually
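+				// Backfill an all-true bitmap only when the query result carried
+				// no validity data at all; an explicit bitmap is kept as-is.
+				if 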
len(fieldData.GetValidData()) == 0 { + fieldData.ValidData = make([]bool, upsertIDSize) + for i := range fieldData.ValidData { + fieldData.ValidData[i] = true + } } } } @@ -409,6 +429,28 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error { log.Info("update field data failed", zap.Error(err)) return err } + + // Collect merged validity for nullable vector fields + for _, ctx := range nullableVectorContexts { + if ctx.hasUpsertData { + ctx.mergedValid = append(ctx.mergedValid, ctx.upsertField.GetValidData()[idx]) + } else { + ctx.mergedValid = append(ctx.mergedValid, ctx.existField.GetValidData()[existIndex]) + } + } + } + + // Rebuild nullable vector fields + for i, fieldData := range it.insertFieldData { + ctx, ok := nullableVectorContexts[fieldData.FieldId] + if !ok { + continue + } + + newFieldData := rebuildNullableVectorFieldData(ctx, updateIdxInUpsert, upsertIDs, existPKToIndex) + if newFieldData != nil { + it.insertFieldData[i] = newFieldData + } } } @@ -482,6 +524,196 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error { return nil } +func buildNullableVectorIdxMap(validData []bool) []int64 { + if len(validData) == 0 { + return nil + } + idxMap := make([]int64, len(validData)) + dataIdx := int64(0) + for i, valid := range validData { + if valid { + idxMap[i] = dataIdx + dataIdx++ + } else { + idxMap[i] = -1 + } + } + return idxMap +} + +type nullableVectorMergeContext struct { + existIdxMap []int64 // row index -> data index for exist data + upsertIdxMap []int64 // row index -> data index for upsert data + mergedValid []bool // merged validity for the output + existField *schemapb.FieldData + upsertField *schemapb.FieldData + hasUpsertData bool // whether upsert request contains this field +} + +func rebuildNullableVectorFieldData( + ctx *nullableVectorMergeContext, + updateIdxInUpsert []int, + upsertIDs *schemapb.IDs, + existPKToIndex map[interface{}]int, +) *schemapb.FieldData { + var sourceField *schemapb.FieldData + var sourceIdxMap []int64 + + if ctx.hasUpsertData { + sourceField = ctx.upsertField + sourceIdxMap = ctx.upsertIdxMap + } else { + sourceField = ctx.existField + sourceIdxMap = ctx.existIdxMap + } + + if sourceField == nil { + return nil + } + + newFieldData := prepareNullableVectorFieldData(sourceField, int64(len(updateIdxInUpsert))) + newFieldData.ValidData = ctx.mergedValid + + // Append vectors from source using the correct indices + for rowIdx, idx := range updateIdxInUpsert { + var sourceRowIdx int + if ctx.hasUpsertData { + sourceRowIdx = idx // upsertMsg uses updateIdxInUpsert + } else { + // existFieldData uses existIndex + oldPK := typeutil.GetPK(upsertIDs, int64(idx)) + sourceRowIdx = existPKToIndex[oldPK] + } + + // Check if this row is valid (has vector data) + if rowIdx < len(ctx.mergedValid) && ctx.mergedValid[rowIdx] { + dataIdx := int64(sourceRowIdx) + if len(sourceIdxMap) > 0 && sourceRowIdx < len(sourceIdxMap) { + dataIdx = sourceIdxMap[sourceRowIdx] + } + if dataIdx >= 0 { + appendSingleVector(newFieldData, sourceField, dataIdx) + } + } + } + + return newFieldData +} + +func prepareNullableVectorFieldData(sample *schemapb.FieldData, capacity int64) *schemapb.FieldData { + fd := &schemapb.FieldData{ + Type: sample.Type, + FieldName: sample.FieldName, + FieldId: sample.FieldId, + IsDynamic: sample.IsDynamic, + ValidData: make([]bool, 0, capacity), + } + + vectorField := sample.GetVectors() + if vectorField == nil { + return fd + } + dim := vectorField.GetDim() + vectors := &schemapb.FieldData_Vectors{ + Vectors: 
&schemapb.VectorField{ + Dim: dim, + }, + } + + switch vectorField.Data.(type) { + case *schemapb.VectorField_FloatVector: + vectors.Vectors.Data = &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: make([]float32, 0, dim*capacity), + }, + } + case *schemapb.VectorField_Float16Vector: + vectors.Vectors.Data = &schemapb.VectorField_Float16Vector{ + Float16Vector: make([]byte, 0, dim*2*capacity), + } + case *schemapb.VectorField_Bfloat16Vector: + vectors.Vectors.Data = &schemapb.VectorField_Bfloat16Vector{ + Bfloat16Vector: make([]byte, 0, dim*2*capacity), + } + case *schemapb.VectorField_BinaryVector: + vectors.Vectors.Data = &schemapb.VectorField_BinaryVector{ + BinaryVector: make([]byte, 0, dim/8*capacity), + } + case *schemapb.VectorField_Int8Vector: + vectors.Vectors.Data = &schemapb.VectorField_Int8Vector{ + Int8Vector: make([]byte, 0, dim*capacity), + } + case *schemapb.VectorField_SparseFloatVector: + vectors.Vectors.Data = &schemapb.VectorField_SparseFloatVector{ + SparseFloatVector: &schemapb.SparseFloatArray{ + Contents: make([][]byte, 0, capacity), + Dim: vectorField.GetSparseFloatVector().GetDim(), + }, + } + } + fd.Field = vectors + return fd +} + +func appendSingleVector(target *schemapb.FieldData, source *schemapb.FieldData, dataIdx int64) { + targetVectors := target.GetVectors() + sourceVectors := source.GetVectors() + if targetVectors == nil || sourceVectors == nil { + return + } + dim := sourceVectors.GetDim() + + switch sv := sourceVectors.Data.(type) { + case *schemapb.VectorField_FloatVector: + tv := targetVectors.Data.(*schemapb.VectorField_FloatVector) + start := dataIdx * dim + end := start + dim + if end <= int64(len(sv.FloatVector.Data)) { + tv.FloatVector.Data = append(tv.FloatVector.Data, sv.FloatVector.Data[start:end]...) + } + case *schemapb.VectorField_Float16Vector: + tv := targetVectors.Data.(*schemapb.VectorField_Float16Vector) + unitSize := dim * 2 + start := dataIdx * unitSize + end := start + unitSize + if end <= int64(len(sv.Float16Vector)) { + tv.Float16Vector = append(tv.Float16Vector, sv.Float16Vector[start:end]...) + } + case *schemapb.VectorField_Bfloat16Vector: + tv := targetVectors.Data.(*schemapb.VectorField_Bfloat16Vector) + unitSize := dim * 2 + start := dataIdx * unitSize + end := start + unitSize + if end <= int64(len(sv.Bfloat16Vector)) { + tv.Bfloat16Vector = append(tv.Bfloat16Vector, sv.Bfloat16Vector[start:end]...) + } + case *schemapb.VectorField_BinaryVector: + tv := targetVectors.Data.(*schemapb.VectorField_BinaryVector) + unitSize := dim / 8 + start := dataIdx * unitSize + end := start + unitSize + if end <= int64(len(sv.BinaryVector)) { + tv.BinaryVector = append(tv.BinaryVector, sv.BinaryVector[start:end]...) + } + case *schemapb.VectorField_Int8Vector: + tv := targetVectors.Data.(*schemapb.VectorField_Int8Vector) + start := dataIdx * dim + end := start + dim + if end <= int64(len(sv.Int8Vector)) { + tv.Int8Vector = append(tv.Int8Vector, sv.Int8Vector[start:end]...) 
+ } + case *schemapb.VectorField_SparseFloatVector: + tv := targetVectors.Data.(*schemapb.VectorField_SparseFloatVector) + if dataIdx < int64(len(sv.SparseFloatVector.Contents)) { + tv.SparseFloatVector.Contents = append(tv.SparseFloatVector.Contents, sv.SparseFloatVector.Contents[dataIdx]) + // Update dimension if necessary + if sv.SparseFloatVector.Dim > tv.SparseFloatVector.Dim { + tv.SparseFloatVector.Dim = sv.SparseFloatVector.Dim + } + } + } +} + // ToCompressedFormatNullable converts the field data from full format nullable to compressed format nullable func ToCompressedFormatNullable(field *schemapb.FieldData) error { if getValidNumber(field.GetValidData()) == len(field.GetValidData()) { diff --git a/internal/proxy/task_upsert_test.go b/internal/proxy/task_upsert_test.go index da64ecdf13..0a523ea5f3 100644 --- a/internal/proxy/task_upsert_test.go +++ b/internal/proxy/task_upsert_test.go @@ -1825,3 +1825,383 @@ func TestUpsertTask_queryPreExecute_EmptyDataArray(t *testing.T) { }) }) } + +func TestBuildNullableVectorIdxMap(t *testing.T) { + t.Run("empty validData", func(t *testing.T) { + result := buildNullableVectorIdxMap(nil) + assert.Nil(t, result) + + result = buildNullableVectorIdxMap([]bool{}) + assert.Nil(t, result) + }) + + t.Run("all valid", func(t *testing.T) { + validData := []bool{true, true, true, true} + result := buildNullableVectorIdxMap(validData) + assert.Equal(t, []int64{0, 1, 2, 3}, result) + }) + + t.Run("all null", func(t *testing.T) { + validData := []bool{false, false, false, false} + result := buildNullableVectorIdxMap(validData) + assert.Equal(t, []int64{-1, -1, -1, -1}, result) + }) + + t.Run("mixed valid and null", func(t *testing.T) { + validData := []bool{true, false, true, false, true} + result := buildNullableVectorIdxMap(validData) + // dataIdx: 0 for row 0, -1 for row 1, 1 for row 2, -1 for row 3, 2 for row 4 + assert.Equal(t, []int64{0, -1, 1, -1, 2}, result) + }) + + t.Run("alternating pattern", func(t *testing.T) { + validData := []bool{false, true, false, true, false, true} + result := buildNullableVectorIdxMap(validData) + assert.Equal(t, []int64{-1, 0, -1, 1, -1, 2}, result) + }) +} + +func TestPrepareNullableVectorFieldData(t *testing.T) { + dim := int64(4) + + t.Run("FloatVector", func(t *testing.T) { + sample := &schemapb.FieldData{ + Type: schemapb.DataType_FloatVector, + FieldName: "float_vec", + FieldId: 100, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_FloatVector{FloatVector: &schemapb.FloatArray{Data: []float32{1, 2, 3, 4}}}, + }, + }, + } + result := prepareNullableVectorFieldData(sample, 10) + assert.Equal(t, schemapb.DataType_FloatVector, result.Type) + assert.Equal(t, "float_vec", result.FieldName) + assert.Equal(t, int64(100), result.FieldId) + assert.Equal(t, dim, result.GetVectors().GetDim()) + assert.Empty(t, result.GetVectors().GetFloatVector().Data) + }) + + t.Run("Float16Vector", func(t *testing.T) { + sample := &schemapb.FieldData{ + Type: schemapb.DataType_Float16Vector, + FieldName: "fp16_vec", + FieldId: 101, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_Float16Vector{Float16Vector: []byte{1, 2, 3, 4, 5, 6, 7, 8}}, + }, + }, + } + result := prepareNullableVectorFieldData(sample, 10) + assert.Equal(t, schemapb.DataType_Float16Vector, result.Type) + assert.Equal(t, dim, result.GetVectors().GetDim()) + assert.Empty(t, result.GetVectors().GetFloat16Vector()) + }) + + t.Run("BFloat16Vector", 
func(t *testing.T) { + sample := &schemapb.FieldData{ + Type: schemapb.DataType_BFloat16Vector, + FieldName: "bf16_vec", + FieldId: 102, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_Bfloat16Vector{Bfloat16Vector: []byte{1, 2, 3, 4, 5, 6, 7, 8}}, + }, + }, + } + result := prepareNullableVectorFieldData(sample, 10) + assert.Equal(t, schemapb.DataType_BFloat16Vector, result.Type) + assert.Equal(t, dim, result.GetVectors().GetDim()) + assert.Empty(t, result.GetVectors().GetBfloat16Vector()) + }) + + t.Run("BinaryVector", func(t *testing.T) { + sample := &schemapb.FieldData{ + Type: schemapb.DataType_BinaryVector, + FieldName: "binary_vec", + FieldId: 103, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: int64(32), // binary vector dim is in bits + Data: &schemapb.VectorField_BinaryVector{BinaryVector: []byte{1, 2, 3, 4}}, + }, + }, + } + result := prepareNullableVectorFieldData(sample, 10) + assert.Equal(t, schemapb.DataType_BinaryVector, result.Type) + assert.Equal(t, int64(32), result.GetVectors().GetDim()) + assert.Empty(t, result.GetVectors().GetBinaryVector()) + }) + + t.Run("SparseFloatVector", func(t *testing.T) { + sample := &schemapb.FieldData{ + Type: schemapb.DataType_SparseFloatVector, + FieldName: "sparse_vec", + FieldId: 104, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: 100, + Data: &schemapb.VectorField_SparseFloatVector{ + SparseFloatVector: &schemapb.SparseFloatArray{ + Dim: 100, + Contents: [][]byte{{1, 2, 3}}, + }, + }, + }, + }, + } + result := prepareNullableVectorFieldData(sample, 10) + assert.Equal(t, schemapb.DataType_SparseFloatVector, result.Type) + assert.Empty(t, result.GetVectors().GetSparseFloatVector().Contents) + }) + + t.Run("nil vectors", func(t *testing.T) { + sample := &schemapb.FieldData{ + Type: schemapb.DataType_FloatVector, + FieldName: "empty_vec", + FieldId: 105, + } + result := prepareNullableVectorFieldData(sample, 10) + assert.Equal(t, schemapb.DataType_FloatVector, result.Type) + assert.Nil(t, result.GetVectors()) + }) +} + +func TestAppendSingleVector(t *testing.T) { + t.Run("FloatVector", func(t *testing.T) { + dim := int64(4) + source := &schemapb.FieldData{ + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{Data: []float32{1, 2, 3, 4, 5, 6, 7, 8}}, // 2 vectors + }, + }, + }, + } + target := &schemapb.FieldData{ + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_FloatVector{FloatVector: &schemapb.FloatArray{Data: []float32{}}}, + }, + }, + } + + appendSingleVector(target, source, 0) + assert.Equal(t, []float32{1, 2, 3, 4}, target.GetVectors().GetFloatVector().Data) + + appendSingleVector(target, source, 1) + assert.Equal(t, []float32{1, 2, 3, 4, 5, 6, 7, 8}, target.GetVectors().GetFloatVector().Data) + }) + + t.Run("Float16Vector", func(t *testing.T) { + dim := int64(2) + source := &schemapb.FieldData{ + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_Float16Vector{Float16Vector: []byte{1, 2, 3, 4, 5, 6, 7, 8}}, // 2 vectors (dim*2 bytes each) + }, + }, + } + target := &schemapb.FieldData{ + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_Float16Vector{Float16Vector: []byte{}}, + }, + }, + } + + 
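// each float16 vector occupies dim*2 bytes, so with dim=2 dataIdx 0
+		// copies bytes [0, 4) and dataIdx 1 copies bytes [4, 8)
+		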
appendSingleVector(target, source, 0) + assert.Equal(t, []byte{1, 2, 3, 4}, target.GetVectors().GetFloat16Vector()) + + appendSingleVector(target, source, 1) + assert.Equal(t, []byte{1, 2, 3, 4, 5, 6, 7, 8}, target.GetVectors().GetFloat16Vector()) + }) + + t.Run("BinaryVector", func(t *testing.T) { + dim := int64(16) // 16 bits = 2 bytes per vector + source := &schemapb.FieldData{ + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_BinaryVector{BinaryVector: []byte{1, 2, 3, 4}}, // 2 vectors + }, + }, + } + target := &schemapb.FieldData{ + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_BinaryVector{BinaryVector: []byte{}}, + }, + }, + } + + appendSingleVector(target, source, 0) + assert.Equal(t, []byte{1, 2}, target.GetVectors().GetBinaryVector()) + + appendSingleVector(target, source, 1) + assert.Equal(t, []byte{1, 2, 3, 4}, target.GetVectors().GetBinaryVector()) + }) + + t.Run("SparseFloatVector", func(t *testing.T) { + source := &schemapb.FieldData{ + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: 100, + Data: &schemapb.VectorField_SparseFloatVector{ + SparseFloatVector: &schemapb.SparseFloatArray{ + Dim: 100, + Contents: [][]byte{{1, 2}, {3, 4}}, + }, + }, + }, + }, + } + target := &schemapb.FieldData{ + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: 100, + Data: &schemapb.VectorField_SparseFloatVector{ + SparseFloatVector: &schemapb.SparseFloatArray{ + Dim: 0, + Contents: [][]byte{}, + }, + }, + }, + }, + } + + appendSingleVector(target, source, 0) + assert.Equal(t, [][]byte{{1, 2}}, target.GetVectors().GetSparseFloatVector().Contents) + + appendSingleVector(target, source, 1) + assert.Equal(t, [][]byte{{1, 2}, {3, 4}}, target.GetVectors().GetSparseFloatVector().Contents) + }) + + t.Run("nil vectors", func(t *testing.T) { + source := &schemapb.FieldData{} + target := &schemapb.FieldData{} + // Should not panic + appendSingleVector(target, source, 0) + }) +} + +func TestRebuildNullableVectorFieldData(t *testing.T) { + dim := int64(4) + + t.Run("with upsert data - all valid", func(t *testing.T) { + upsertField := &schemapb.FieldData{ + Type: schemapb.DataType_FloatVector, + FieldName: "vec", + FieldId: 100, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_FloatVector{FloatVector: &schemapb.FloatArray{Data: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}}, + }, + }, + ValidData: []bool{true, true, true}, + } + + ctx := &nullableVectorMergeContext{ + upsertIdxMap: []int64{0, 1, 2}, + upsertField: upsertField, + hasUpsertData: true, + mergedValid: []bool{true, true}, + } + + updateIdxInUpsert := []int{0, 2} + upsertIDs := &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1, 2, 3}}}} + existPKToIndex := map[interface{}]int{} + + result := rebuildNullableVectorFieldData(ctx, updateIdxInUpsert, upsertIDs, existPKToIndex) + assert.NotNil(t, result) + assert.Equal(t, []bool{true, true}, result.ValidData) + // Should have 2 vectors: index 0 and index 2 from source + assert.Equal(t, []float32{1, 2, 3, 4, 9, 10, 11, 12}, result.GetVectors().GetFloatVector().Data) + }) + + t.Run("with upsert data - mixed valid/null", func(t *testing.T) { + upsertField := &schemapb.FieldData{ + Type: schemapb.DataType_FloatVector, + FieldName: "vec", + FieldId: 100, + Field: &schemapb.FieldData_Vectors{ + Vectors: 
&schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_FloatVector{FloatVector: &schemapb.FloatArray{Data: []float32{1, 2, 3, 4, 5, 6, 7, 8}}}, // only 2 vectors (compressed) + }, + }, + ValidData: []bool{true, false, true}, // row 0 valid, row 1 null, row 2 valid + } + + ctx := &nullableVectorMergeContext{ + upsertIdxMap: []int64{0, -1, 1}, // row 0 -> data 0, row 1 -> null, row 2 -> data 1 + upsertField: upsertField, + hasUpsertData: true, + mergedValid: []bool{true, false}, // result: row 0 valid, row 1 null + } + + updateIdxInUpsert := []int{0, 1} + upsertIDs := &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1, 2, 3}}}} + existPKToIndex := map[interface{}]int{} + + result := rebuildNullableVectorFieldData(ctx, updateIdxInUpsert, upsertIDs, existPKToIndex) + assert.NotNil(t, result) + assert.Equal(t, []bool{true, false}, result.ValidData) + // Should have only 1 vector (row 0), row 1 is null + assert.Equal(t, []float32{1, 2, 3, 4}, result.GetVectors().GetFloatVector().Data) + }) + + t.Run("without upsert data - use exist data", func(t *testing.T) { + existField := &schemapb.FieldData{ + Type: schemapb.DataType_FloatVector, + FieldName: "vec", + FieldId: 100, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, + Data: &schemapb.VectorField_FloatVector{FloatVector: &schemapb.FloatArray{Data: []float32{1, 2, 3, 4, 5, 6, 7, 8}}}, + }, + }, + ValidData: []bool{true, true}, + } + + ctx := &nullableVectorMergeContext{ + existIdxMap: []int64{0, 1}, + existField: existField, + hasUpsertData: false, + mergedValid: []bool{true, true}, + } + + updateIdxInUpsert := []int{0, 1} + upsertIDs := &schemapb.IDs{IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{10, 20}}}} + existPKToIndex := map[interface{}]int{int64(10): 0, int64(20): 1} + + result := rebuildNullableVectorFieldData(ctx, updateIdxInUpsert, upsertIDs, existPKToIndex) + assert.NotNil(t, result) + assert.Equal(t, []bool{true, true}, result.ValidData) + assert.Equal(t, []float32{1, 2, 3, 4, 5, 6, 7, 8}, result.GetVectors().GetFloatVector().Data) + }) + + t.Run("nil source field", func(t *testing.T) { + ctx := &nullableVectorMergeContext{ + hasUpsertData: false, + existField: nil, + mergedValid: []bool{true}, + } + + result := rebuildNullableVectorFieldData(ctx, []int{0}, nil, nil) + assert.Nil(t, result) + }) +} diff --git a/internal/proxy/validate_util.go b/internal/proxy/validate_util.go index bea97011ff..7935480aee 100644 --- a/internal/proxy/validate_util.go +++ b/internal/proxy/validate_util.go @@ -226,6 +226,14 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil return err } + expectedRows := getExpectedVectorRows(field, f) + if field.GetVectors() == nil { + if expectedRows != 0 { + return errNumRowsMismatch(field.GetFieldName(), 0) + } + continue + } + dim, err := typeutil.GetDim(f) if err != nil { return err @@ -240,7 +248,6 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil return errDimMismatch(field.GetFieldName(), dataDim, dim) } - expectedRows := getExpectedVectorRows(field, f) if n != expectedRows { return errNumRowsMismatch(field.GetFieldName(), n) } @@ -251,6 +258,14 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil return err } + expectedRows := getExpectedVectorRows(field, f) + if field.GetVectors() == nil { + if expectedRows != 0 { + return errNumRowsMismatch(field.GetFieldName(), 0) + } + continue + } + dim, err := 
typeutil.GetDim(f) if err != nil { return err @@ -265,7 +280,6 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil return err } - expectedRows := getExpectedVectorRows(field, f) if n != expectedRows { return errNumRowsMismatch(field.GetFieldName(), n) } @@ -276,6 +290,14 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil return err } + expectedRows := getExpectedVectorRows(field, f) + if field.GetVectors() == nil { + if expectedRows != 0 { + return errNumRowsMismatch(field.GetFieldName(), 0) + } + continue + } + dim, err := typeutil.GetDim(f) if err != nil { return err @@ -290,7 +312,6 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil return err } - expectedRows := getExpectedVectorRows(field, f) if n != expectedRows { return errNumRowsMismatch(field.GetFieldName(), n) } @@ -301,6 +322,14 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil return err } + expectedRows := getExpectedVectorRows(field, f) + if field.GetVectors() == nil { + if expectedRows != 0 { + return errNumRowsMismatch(field.GetFieldName(), 0) + } + continue + } + dim, err := typeutil.GetDim(f) if err != nil { return err @@ -315,7 +344,6 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil return err } - expectedRows := getExpectedVectorRows(field, f) if n != expectedRows { return errNumRowsMismatch(field.GetFieldName(), n) } @@ -325,8 +353,16 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil if err != nil { return err } - n := uint64(len(field.GetVectors().GetSparseFloatVector().Contents)) + expectedRows := getExpectedVectorRows(field, f) + if field.GetVectors() == nil || field.GetVectors().GetSparseFloatVector() == nil { + if expectedRows != 0 { + return errNumRowsMismatch(field.GetFieldName(), 0) + } + continue + } + + n := uint64(len(field.GetVectors().GetSparseFloatVector().Contents)) if n != expectedRows { return errNumRowsMismatch(field.GetFieldName(), n) } @@ -337,6 +373,14 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil return err } + expectedRows := getExpectedVectorRows(field, f) + if field.GetVectors() == nil { + if expectedRows != 0 { + return errNumRowsMismatch(field.GetFieldName(), 0) + } + continue + } + dim, err := typeutil.GetDim(f) if err != nil { return err @@ -351,7 +395,6 @@ func (v *validateUtil) checkAligned(data []*schemapb.FieldData, schema *typeutil return errDimMismatch(field.GetFieldName(), dataDim, dim) } - expectedRows := getExpectedVectorRows(field, f) if n != expectedRows { return errNumRowsMismatch(field.GetFieldName(), n) } diff --git a/internal/proxy/validate_util_test.go b/internal/proxy/validate_util_test.go index 4d015926d4..f352b39798 100644 --- a/internal/proxy/validate_util_test.go +++ b/internal/proxy/validate_util_test.go @@ -1998,6 +1998,292 @@ func Test_validateUtil_checkAligned(t *testing.T) { assert.NoError(t, err) }) + + t.Run("nullable float vector all null", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_FloatVector, + ValidData: []bool{false, false, false}, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "8", + }, + }, + Nullable: true, + }, + }, + } + h, err := typeutil.CreateSchemaHelper(schema) + 
assert.NoError(t, err) + + v := newValidateUtil() + + err = v.checkAligned(data, h, 3) + + assert.NoError(t, err) + }) + + t.Run("nullable float vector partial null", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_FloatVector, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, // 2 vectors of dim 8 + }, + }, + Dim: 8, + }, + }, + ValidData: []bool{true, false, true, false}, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "8", + }, + }, + Nullable: true, + }, + }, + } + h, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + v := newValidateUtil() + + err = v.checkAligned(data, h, 4) + + assert.NoError(t, err) + }) + + t.Run("nullable float vector mismatch valid count", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_FloatVector, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: []float32{1, 2, 3, 4, 5, 6, 7, 8}, // only 1 vector + }, + }, + Dim: 8, + }, + }, + ValidData: []bool{true, false, true, false}, // 2 valid + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "8", + }, + }, + Nullable: true, + }, + }, + } + h, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + v := newValidateUtil() + + err = v.checkAligned(data, h, 4) + + assert.Error(t, err) + }) + + t.Run("nullable binary vector all null", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_BinaryVector, + ValidData: []bool{false, false, false}, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_BinaryVector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "8", + }, + }, + Nullable: true, + }, + }, + } + h, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + v := newValidateUtil() + + err = v.checkAligned(data, h, 3) + + assert.NoError(t, err) + }) + + t.Run("nullable float16 vector all null", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Float16Vector, + ValidData: []bool{false, false, false}, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Float16Vector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "8", + }, + }, + Nullable: true, + }, + }, + } + h, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + v := newValidateUtil() + + err = v.checkAligned(data, h, 3) + + assert.NoError(t, err) + }) + + t.Run("nullable bfloat16 vector all null", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_BFloat16Vector, + ValidData: []bool{false, false, false}, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ 
+ { + Name: "test", + DataType: schemapb.DataType_BFloat16Vector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "8", + }, + }, + Nullable: true, + }, + }, + } + h, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + v := newValidateUtil() + + err = v.checkAligned(data, h, 3) + + assert.NoError(t, err) + }) + + t.Run("nullable int8 vector all null", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_Int8Vector, + ValidData: []bool{false, false, false}, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_Int8Vector, + TypeParams: []*commonpb.KeyValuePair{ + { + Key: common.DimKey, + Value: "8", + }, + }, + Nullable: true, + }, + }, + } + h, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + v := newValidateUtil() + + err = v.checkAligned(data, h, 3) + + assert.NoError(t, err) + }) + + t.Run("nullable sparse vector all null", func(t *testing.T) { + data := []*schemapb.FieldData{ + { + FieldName: "test", + Type: schemapb.DataType_SparseFloatVector, + ValidData: []bool{false, false, false}, + }, + } + + schema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "test", + DataType: schemapb.DataType_SparseFloatVector, + Nullable: true, + }, + }, + } + h, err := typeutil.CreateSchemaHelper(schema) + assert.NoError(t, err) + + v := newValidateUtil() + + err = v.checkAligned(data, h, 3) + + assert.NoError(t, err) + }) } func Test_validateUtil_Validate(t *testing.T) { diff --git a/tests/go_client/testcases/nullable_default_value_test.go b/tests/go_client/testcases/nullable_default_value_test.go index 9abca78f73..d4bc8b2170 100644 --- a/tests/go_client/testcases/nullable_default_value_test.go +++ b/tests/go_client/testcases/nullable_default_value_test.go @@ -1732,43 +1732,44 @@ func TestNullableVectorUpsert(t *testing.T) { ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) mc := hp.CreateDefaultMilvusClient(ctx, t) - collName := common.GenRandomString("nullable_vec_ups", 5) + // Create collection with pk, scalar, and nullable vector fields + collName := common.GenRandomString("nullable_vec_upsert", 5) pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true) if autoID { pkField.AutoID = true } + scalarField := entity.NewField().WithName("scalar").WithDataType(entity.FieldTypeInt32).WithNullable(true) vecField := entity.NewField().WithName(common.DefaultFloatVecFieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim).WithNullable(true) - schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(vecField) + schema := entity.NewSchema().WithName(collName).WithField(pkField).WithField(scalarField).WithField(vecField) err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema).WithConsistencyLevel(entity.ClStrong)) common.CheckErr(t, err, true) - // insert initial data with 50% null + // Insert initial data: 100 rows + // Rows 0 to nullPercent-1: valid vector, scalar = i*10 + // Rows nullPercent to nb-1: null vector, scalar = i*10 nb := 100 nullPercent := 50 validData := make([]bool, nb) validCount := 0 for i := range nb { - validData[i] = (i % 100) >= nullPercent + validData[i] = i < nullPercent if validData[i] { validCount++ } } pkData := make([]int64, nb) + scalarData := make([]int32, nb) + scalarValidData := make([]bool, nb) for 
i := range nb { pkData[i] = int64(i) + scalarData[i] = int32(i * 10) + scalarValidData[i] = true } pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, pkData) - - pkToVecIdx := make(map[int64]int) - vecIdx := 0 - for i := range nb { - if validData[i] { - pkToVecIdx[int64(i)] = vecIdx - vecIdx++ - } - } + scalarColumn, err := column.NewNullableColumnInt32("scalar", scalarData, scalarValidData) + common.CheckErr(t, err, true) vectors := make([][]float32, validCount) for i := range validCount { @@ -1783,9 +1784,9 @@ func TestNullableVectorUpsert(t *testing.T) { var insertRes client.InsertResult if autoID { - insertRes, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(collName).WithColumns(vecColumn)) + insertRes, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(collName).WithColumns(scalarColumn, vecColumn)) } else { - insertRes, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(collName).WithColumns(pkColumn, vecColumn)) + insertRes, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(collName, pkColumn, scalarColumn, vecColumn)) } common.CheckErr(t, err, true) require.EqualValues(t, nb, insertRes.InsertCount) @@ -1814,125 +1815,214 @@ func TestNullableVectorUpsert(t *testing.T) { err = loadTask.Await(ctx) common.CheckErr(t, err, true) - // upsert: change first 25 rows (originally null) to valid, change rows 50-74 (originally valid) to null - upsertNb := 50 - upsertValidData := make([]bool, upsertNb) - for i := range upsertNb { - upsertValidData[i] = i < 25 - } - - upsertPkData := make([]int64, upsertNb) - for i := range upsertNb { - if i < 25 { - upsertPkData[i] = actualPkData[i] - } else { - upsertPkData[i] = actualPkData[i+25] - } - } - upsertPkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, upsertPkData) - - upsertVectors := make([][]float32, 25) - for i := range 25 { - vec := make([]float32, common.DefaultDim) - for j := range common.DefaultDim { - vec[j] = float32((i+100)*common.DefaultDim+j) / 10000.0 - } - upsertVectors[i] = vec - } - upsertVecColumn, err := column.NewNullableColumnFloatVector(common.DefaultFloatVecFieldName, common.DefaultDim, upsertVectors, upsertValidData) - common.CheckErr(t, err, true) - - upsertRes, err := mc.Upsert(ctx, client.NewColumnBasedInsertOption(collName, upsertPkColumn, upsertVecColumn)) - common.CheckErr(t, err, true) - require.EqualValues(t, upsertNb, upsertRes.UpsertCount) - - var upsertedPks []int64 - if autoID { - upsertedIDs := upsertRes.IDs.(*column.ColumnInt64) - upsertedPks = upsertedIDs.Data() - require.EqualValues(t, upsertNb, len(upsertedPks), "upserted PK count should match") - } else { - upsertedPks = upsertPkData - } - + // Track expected state expectedVectorMap := make(map[int64][]float32) - for i := 0; i < 25; i++ { - expectedVectorMap[upsertedPks[i]] = upsertVectors[i] - } - for i := 25; i < 50; i++ { - expectedVectorMap[upsertedPks[i]] = nil - } - for i := 25; i < 50; i++ { - expectedVectorMap[actualPkData[i]] = nil - } - for i := 75; i < 100; i++ { - vecIdx := i - 50 - expectedVectorMap[actualPkData[i]] = vectors[vecIdx] + expectedScalarMap := make(map[int64]int32) + for i := range nb { + expectedScalarMap[actualPkData[i]] = scalarData[i] + if i < nullPercent { + expectedVectorMap[actualPkData[i]] = vectors[i] + } else { + expectedVectorMap[actualPkData[i]] = nil + } } - time.Sleep(10 * time.Second) - flushTask, err = mc.Flush(ctx, client.NewFlushOption(collName)) - common.CheckErr(t, err, true) - err = flushTask.Await(ctx) - common.CheckErr(t, err, true) + // Helper: flush, reload, 
search and query verify + flushAndVerify := func(expectedValidCount int, context string) { + time.Sleep(10 * time.Second) + flushTask, err := mc.Flush(ctx, client.NewFlushOption(collName)) + common.CheckErr(t, err, true) + err = flushTask.Await(ctx) + common.CheckErr(t, err, true) - err = mc.ReleaseCollection(ctx, client.NewReleaseCollectionOption(collName)) - common.CheckErr(t, err, true) + err = mc.ReleaseCollection(ctx, client.NewReleaseCollectionOption(collName)) + common.CheckErr(t, err, true) - loadTask, err = mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName)) - common.CheckErr(t, err, true) - err = loadTask.Await(ctx) - common.CheckErr(t, err, true) + loadTask, err := mc.LoadCollection(ctx, client.NewLoadCollectionOption(collName)) + common.CheckErr(t, err, true) + err = loadTask.Await(ctx) + common.CheckErr(t, err, true) - expectedValidCount := 50 - searchVec := entity.FloatVector(common.GenFloatVector(common.DefaultDim)) - searchRes, err := mc.Search(ctx, client.NewSearchOption(collName, 50, []entity.Vector{searchVec}).WithANNSField(common.DefaultFloatVecFieldName)) - common.CheckErr(t, err, true) - require.EqualValues(t, 1, len(searchRes)) - searchIDs := searchRes[0].IDs.(*column.ColumnInt64).Data() - require.EqualValues(t, expectedValidCount, len(searchIDs), "search should return all 50 valid vectors") + // Search verify + searchVec := entity.FloatVector(common.GenFloatVector(common.DefaultDim)) + searchRes, err := mc.Search(ctx, client.NewSearchOption(collName, 100, []entity.Vector{searchVec}).WithANNSField(common.DefaultFloatVecFieldName)) + common.CheckErr(t, err, true) + require.EqualValues(t, 1, len(searchRes)) + require.EqualValues(t, expectedValidCount, searchRes[0].ResultCount, "%s: search should return %d valid vectors", context, expectedValidCount) - verifyVectorData := func(queryResult client.ResultSet, context string) { - pkCol := queryResult.GetColumn(common.DefaultInt64FieldName).(*column.ColumnInt64) - vecCol := queryResult.GetColumn(common.DefaultFloatVecFieldName).(*column.ColumnFloatVector) - for i := 0; i < queryResult.ResultCount; i++ { + // Query verify all rows + queryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("int64 in [%s]", int64SliceToString(actualPkData))).WithOutputFields("scalar", common.DefaultFloatVecFieldName)) + common.CheckErr(t, err, true) + require.EqualValues(t, nb, queryRes.ResultCount, "%s: query should return %d rows", context, nb) + + pkCol := queryRes.GetColumn(common.DefaultInt64FieldName).(*column.ColumnInt64) + scalarCol := queryRes.GetColumn("scalar").(*column.ColumnInt32) + vecCol := queryRes.GetColumn(common.DefaultFloatVecFieldName).(*column.ColumnFloatVector) + + for i := 0; i < queryRes.ResultCount; i++ { pk, _ := pkCol.GetAsInt64(i) + scalarVal, _ := scalarCol.GetAsInt64(i) isNull, _ := vecCol.IsNull(i) - expectedVec, exists := expectedVectorMap[pk] - require.True(t, exists, "%s: unexpected PK %d in query results", context, pk) + expectedScalar := expectedScalarMap[pk] + require.EqualValues(t, expectedScalar, scalarVal, "%s: scalar mismatch for pk %d", context, pk) + expectedVec := expectedVectorMap[pk] if expectedVec != nil { require.False(t, isNull, "%s: vector should not be null for pk %d", context, pk) vecData, _ := vecCol.Get(i) queriedVec := []float32(vecData.(entity.FloatVector)) - require.EqualValues(t, common.DefaultDim, len(queriedVec), "%s: vector dimension should match for pk %d", context, pk) for j := range expectedVec { - require.InDelta(t, expectedVec[j], 
queriedVec[j], 1e-6, "%s: vector element %d should match for pk %d", context, j, pk) + require.InDelta(t, expectedVec[j], queriedVec[j], 1e-6, "%s: vector element %d mismatch for pk %d", context, j, pk) } } else { require.True(t, isNull, "%s: vector should be null for pk %d", context, pk) - vecData, _ := vecCol.Get(i) - require.Nil(t, vecData, "%s: null vector data should be nil for pk %d", context, pk) } } } - searchQueryRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("int64 in [%s]", int64SliceToString(searchIDs))).WithOutputFields(common.DefaultFloatVecFieldName)) - common.CheckErr(t, err, true) - verifyVectorData(searchQueryRes, "All valid vectors after upsert") + // Upsert 1: Update all 100 rows to null vectors + upsert1PkData := make([]int64, nb) + for i := range nb { + upsert1PkData[i] = actualPkData[i] + } + upsert1PkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, upsert1PkData) - upsertedToValidPKs := upsertedPks[0:25] - queryUpsertedRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("int64 in [%s]", int64SliceToString(upsertedToValidPKs))).WithOutputFields(common.DefaultFloatVecFieldName)) + allNullValidData := make([]bool, nb) + upsert1VecColumn, err := column.NewNullableColumnFloatVector(common.DefaultFloatVecFieldName, common.DefaultDim, [][]float32{}, allNullValidData) common.CheckErr(t, err, true) - require.EqualValues(t, 25, queryUpsertedRes.ResultCount, "should have 25 rows for upserted to valid") - verifyVectorData(queryUpsertedRes, "Upserted valid rows") - countRes, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter("").WithOutputFields("count(*)")) + upsert1ScalarData := make([]int32, nb) + upsert1ScalarValidData := make([]bool, nb) + for i := range nb { + upsert1ScalarData[i] = int32(i * 100) + upsert1ScalarValidData[i] = true + } + upsert1ScalarColumn, err := column.NewNullableColumnInt32("scalar", upsert1ScalarData, upsert1ScalarValidData) common.CheckErr(t, err, true) - totalCount, err := countRes.Fields[0].GetAsInt64(0) + + upsertRes1, err := mc.Upsert(ctx, client.NewColumnBasedInsertOption(collName, upsert1PkColumn, upsert1ScalarColumn, upsert1VecColumn)) common.CheckErr(t, err, true) - require.EqualValues(t, nb, totalCount, "total count after upsert should still be %d", nb) + require.EqualValues(t, nb, upsertRes1.UpsertCount) + + // For AutoID=true, upsert returns new IDs, need to update actualPkData + if autoID { + upsertedIDs := upsertRes1.IDs.(*column.ColumnInt64) + newPkData := upsertedIDs.Data() + // Clear old expected state + expectedVectorMap = make(map[int64][]float32) + expectedScalarMap = make(map[int64]int32) + // Update actualPkData with new IDs + actualPkData = newPkData + } + + // Update expected state: all vectors null + for i := range nb { + expectedVectorMap[actualPkData[i]] = nil + expectedScalarMap[actualPkData[i]] = upsert1ScalarData[i] + } + + // Verify after upsert 1: search should return 0 (all null) + flushAndVerify(0, "After Upsert1-AllNull") + + // Upsert 2: Update rows nullPercent to nb-1 to valid vectors + upsert2Nb := nb - nullPercent + upsert2PkData := make([]int64, upsert2Nb) + for i := range upsert2Nb { + upsert2PkData[i] = actualPkData[i+nullPercent] + } + upsert2PkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, upsert2PkData) + + upsert2Vectors := make([][]float32, upsert2Nb) + upsert2ValidData := make([]bool, upsert2Nb) + for i := range upsert2Nb { + vec := make([]float32, common.DefaultDim) + for j := range common.DefaultDim 
{ + vec[j] = float32((i+500)*common.DefaultDim+j) / 10000.0 + } + upsert2Vectors[i] = vec + upsert2ValidData[i] = true + } + upsert2VecColumn, err := column.NewNullableColumnFloatVector(common.DefaultFloatVecFieldName, common.DefaultDim, upsert2Vectors, upsert2ValidData) + common.CheckErr(t, err, true) + + upsert2ScalarData := make([]int32, upsert2Nb) + upsert2ScalarValidData := make([]bool, upsert2Nb) + for i := range upsert2Nb { + upsert2ScalarData[i] = int32((i + nullPercent) * 200) + upsert2ScalarValidData[i] = true + } + upsert2ScalarColumn, err := column.NewNullableColumnInt32("scalar", upsert2ScalarData, upsert2ScalarValidData) + common.CheckErr(t, err, true) + + upsertRes2, err := mc.Upsert(ctx, client.NewColumnBasedInsertOption(collName, upsert2PkColumn, upsert2ScalarColumn, upsert2VecColumn)) + common.CheckErr(t, err, true) + require.EqualValues(t, upsert2Nb, upsertRes2.UpsertCount) + + // For AutoID=true, upsert returns new IDs for the upserted rows + if autoID { + upsertedIDs := upsertRes2.IDs.(*column.ColumnInt64) + newPkData := upsertedIDs.Data() + // Update actualPkData for rows nullPercent to nb-1 + for i := range upsert2Nb { + // Remove old expected state + delete(expectedVectorMap, actualPkData[i+nullPercent]) + delete(expectedScalarMap, actualPkData[i+nullPercent]) + // Update to new PK + actualPkData[i+nullPercent] = newPkData[i] + } + } + + // Update expected state: rows nullPercent to nb-1 now have valid vectors + for i := range upsert2Nb { + expectedVectorMap[actualPkData[i+nullPercent]] = upsert2Vectors[i] + expectedScalarMap[actualPkData[i+nullPercent]] = upsert2ScalarData[i] + } + + // Verify after upsert 2: search should return upsert2Nb (rows nullPercent to nb-1 valid) + flushAndVerify(upsert2Nb, "After Upsert2-NullToValid") + + // Upsert 3: Partial update rows 0 to nullPercent-1 (only scalar), vector preserved (still null) + upsert3Nb := nullPercent + upsert3PkData := make([]int64, upsert3Nb) + upsert3ScalarData := make([]int32, upsert3Nb) + upsert3ScalarValidData := make([]bool, upsert3Nb) + for i := range upsert3Nb { + upsert3PkData[i] = actualPkData[i] + upsert3ScalarData[i] = int32(i * 1000) + upsert3ScalarValidData[i] = true + } + upsert3PkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, upsert3PkData) + upsert3ScalarColumn, err := column.NewNullableColumnInt32("scalar", upsert3ScalarData, upsert3ScalarValidData) + common.CheckErr(t, err, true) + + upsertRes3, err := mc.Upsert(ctx, client.NewColumnBasedInsertOption(collName, upsert3PkColumn, upsert3ScalarColumn).WithPartialUpdate(true)) + common.CheckErr(t, err, true) + require.EqualValues(t, upsert3Nb, upsertRes3.UpsertCount) + + // For AutoID=true, upsert returns new IDs for the upserted rows + if autoID { + upsertedIDs := upsertRes3.IDs.(*column.ColumnInt64) + newPkData := upsertedIDs.Data() + // Update actualPkData for rows 0 to nullPercent-1 + for i := range upsert3Nb { + // Remove old expected state + delete(expectedVectorMap, actualPkData[i]) + delete(expectedScalarMap, actualPkData[i]) + // Update to new PK + actualPkData[i] = newPkData[i] + } + } + + // Update expected state: rows 0 to nullPercent-1 scalar updated, vector preserved (null) + for i := range upsert3Nb { + expectedScalarMap[actualPkData[i]] = upsert3ScalarData[i] + // Vector remains null (preserved from before) + expectedVectorMap[actualPkData[i]] = nil + } + + // Verify after upsert 3: search should still return upsert2Nb (only rows nullPercent to nb-1 valid) + flushAndVerify(upsert2Nb, "After Upsert3-PartialUpdate") // 
clean up err = mc.DropCollection(ctx, client.NewDropCollectionOption(collName))