enhance: Optimize partial update merge logic by unifying nullable format (#44197)

issue: #43980
This commit optimizes the partial update merge logic by standardizing
the nullable field representation before the merge runs, avoiding corner
cases in the merge process.

Key changes:
- Unify nullable field data format to FULL FORMAT before merge execution
- Add extensive unit tests for bounds checking and edge cases

The optimization ensures:
- Consistent nullable field representation across the SDK and internal formats
- Proper handling of null values during merge operations
- Prevention of index out-of-bounds errors in vector field updates
- Better error handling and validation for partial update scenarios

This resolves issues where different nullable field formats could cause
merge failures or data corruption during partial update operations.

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
wei liu 2025-09-10 17:27:56 +08:00 committed by GitHub
parent 68fb357515
commit 18371773dd
6 changed files with 1463 additions and 171 deletions
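
As background for the diffs below, here is a small standalone Go sketch (an editorial illustration, not code from this commit; the helper name toFullFormat is invented) of the COMPRESSED nullable layout that SDKs send and the FULL layout used internally, which is the representation the merge logic now unifies on before merging:

package main

import "fmt"

// toFullFormat expands a COMPRESSED nullable column (data holds only the
// non-null values, valid marks which rows are non-null) into the FULL
// representation used internally, where every row has a slot and null rows
// keep the zero value. The name and int64-only scope are illustrative.
func toFullFormat(data []int64, valid []bool) []int64 {
	full := make([]int64, len(valid))
	next := 0 // next unread element of the compressed data slice
	for i, ok := range valid {
		if ok {
			full[i] = data[next]
			next++
		} // null rows stay at the zero value
	}
	return full
}

func main() {
	// Logical column [1, null, 2] in COMPRESSED form: values + validity mask.
	data := []int64{1, 2}
	valid := []bool{true, false, true}
	fmt.Println(toFullFormat(data, valid)) // FULL form: [1 0 2]
}

Once every row has a slot, the merge logic can address any field by row position, even when the row is null.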


@@ -296,17 +296,39 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
 		if fieldData.GetIsDynamic() {
 			fieldName = "$meta"
 		}
-		fieldID, ok := it.schema.MapFieldID(fieldName)
-		if !ok {
-			log.Info("field not found in schema", zap.Any("field", fieldData))
-			return merr.WrapErrParameterInvalidMsg("field not found in schema")
+		fieldSchema, err := it.schema.schemaHelper.GetFieldFromName(fieldName)
+		if err != nil {
+			log.Info("get field schema failed", zap.Error(err))
+			return err
 		}
-		fieldData.FieldId = fieldID
+		fieldData.FieldId = fieldSchema.GetFieldID()
 		fieldData.FieldName = fieldName
+		// compatible with different nullable data format from sdk
+		if len(fieldData.GetValidData()) != 0 {
+			err := FillWithNullValue(fieldData, fieldSchema, int(it.upsertMsg.InsertMsg.NRows()))
+			if err != nil {
+				log.Info("unify null field data format failed", zap.Error(err))
+				return err
+			}
+		}
 	}
-	// Note: the most difficult part is to handle the merge progress of upsert and query result
-	// we need to enable merge logic on different length between upsertFieldData and it.insertFieldData
+	// Two nullable data formats are supported:
+	//
+	// COMPRESSED FORMAT (SDK format, before validateUtil.fillWithValue processing):
+	// Logical data: [1, null, 2]
+	// Storage: Data=[1, 2] + ValidData=[true, false, true]
+	// - Data array contains only non-null values (compressed)
+	// - ValidData array tracks null positions for all rows
+	//
+	// FULL FORMAT (Milvus internal format, after validateUtil.fillWithValue processing):
+	// Logical data: [1, null, 2]
+	// Storage: Data=[1, 0, 2] + ValidData=[true, false, true]
+	// - Data array contains values for all rows (nulls filled with zero/default)
+	// - ValidData array still tracks null positions
+	//
+	// Note: we will unify the nullable format to FULL FORMAT before executing the merge logic
 	insertIdxInUpsert := make([]int, 0)
 	updateIdxInUpsert := make([]int, 0)
 	// 1. split upsert data into insert and update by query result
@@ -367,7 +389,7 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
 		if !ok {
 			return merr.WrapErrParameterInvalidMsg("primary key not found in exist data mapping")
 		}
-		typeutil.AppendFieldDataWithNullData(it.insertFieldData, existFieldData, int64(existIndex), false)
+		typeutil.AppendFieldData(it.insertFieldData, existFieldData, int64(existIndex))
 		err := typeutil.UpdateFieldData(it.insertFieldData, upsertFieldData, int64(baseIdx), int64(idx))
 		baseIdx += 1
 		if err != nil {
@@ -386,7 +408,7 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
 		return lackOfFieldErr
 	}
-	// if the nullable or default value field is not set in upsert request, which means the len(upsertFieldData) < len(it.insertFieldData)
+	// if the nullable field has not passed in upsert request, which means the len(upsertFieldData) < len(it.insertFieldData)
 	// we need to generate the nullable field data before append as insert
 	insertWithNullField := make([]*schemapb.FieldData, 0)
 	upsertFieldMap := lo.SliceToMap(it.upsertMsg.InsertMsg.GetFieldsData(), func(field *schemapb.FieldData) (string, *schemapb.FieldData) {
@@ -407,27 +429,27 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
 			}
 		}
 		for _, idx := range insertIdxInUpsert {
-			typeutil.AppendFieldDataWithNullData(it.insertFieldData, insertWithNullField, int64(idx), true)
+			typeutil.AppendFieldData(it.insertFieldData, insertWithNullField, int64(idx))
 		}
 	}
+	// 4. clean field data with valid data after merge upsert and query result
 	for _, fieldData := range it.insertFieldData {
-		// Note: Since protobuf cannot correctly identify null values, zero values + valid data are used to identify null values,
-		// therefore for field data obtained from query results, if the field is nullable, it needs to clean zero values
-		if len(fieldData.GetValidData()) != 0 && getValidNumber(fieldData.GetValidData()) != len(fieldData.GetValidData()) {
-			err := ResetNullFieldData(fieldData)
+		if len(fieldData.GetValidData()) > 0 {
+			err := ToCompressedFormatNullable(fieldData)
 			if err != nil {
-				log.Info("reset null field data failed", zap.Error(err))
+				log.Info("convert to compressed format nullable failed", zap.Error(err))
 				return err
 			}
 		}
 	}
 	return nil
 }

-func ResetNullFieldData(field *schemapb.FieldData) error {
+// ToCompressedFormatNullable converts the field data from full format nullable to compressed format nullable
+func ToCompressedFormatNullable(field *schemapb.FieldData) error {
+	if getValidNumber(field.GetValidData()) == len(field.GetValidData()) {
+		return nil
+	}
 	switch field.Field.(type) {
 	case *schemapb.FieldData_Scalars:
 		switch sd := field.GetScalars().GetData().(type) {
@@ -529,6 +551,20 @@ func ResetNullFieldData(field *schemapb.FieldData) error {
 				sd.JsonData.Data = ret
 			}
+		case *schemapb.ScalarField_ArrayData:
+			validRowNum := getValidNumber(field.GetValidData())
+			if validRowNum == 0 {
+				sd.ArrayData.Data = make([]*schemapb.ScalarField, 0)
+			} else {
+				ret := make([]*schemapb.ScalarField, 0, validRowNum)
+				for i, valid := range field.GetValidData() {
+					if valid {
+						ret = append(ret, sd.ArrayData.Data[i])
+					}
+				}
+				sd.ArrayData.Data = ret
+			}
 		default:
 			return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined data type:%s", field.Type.String()))
 		}
@@ -540,6 +576,7 @@ func ResetNullFieldData(field *schemapb.FieldData) error {
 	return nil
 }

+// GenNullableFieldData generates nullable field data in FULL FORMAT
 func GenNullableFieldData(field *schemapb.FieldSchema, upsertIDSize int) (*schemapb.FieldData, error) {
 	switch field.DataType {
 	case schemapb.DataType_Bool:
@@ -668,6 +705,24 @@ func GenNullableFieldData(field *schemapb.FieldSchema, upsertIDSize int) (*schemapb.FieldData, error) {
 			},
 		}, nil
+	case schemapb.DataType_Array:
+		return &schemapb.FieldData{
+			FieldId:   field.FieldID,
+			FieldName: field.Name,
+			Type:      field.DataType,
+			IsDynamic: field.IsDynamic,
+			ValidData: make([]bool, upsertIDSize),
+			Field: &schemapb.FieldData_Scalars{
+				Scalars: &schemapb.ScalarField{
+					Data: &schemapb.ScalarField_ArrayData{
+						ArrayData: &schemapb.ArrayArray{
+							Data: make([]*schemapb.ScalarField, upsertIDSize),
+						},
+					},
+				},
+			},
+		}, nil
 	default:
 		return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined scalar data type:%s", field.DataType.String()))
 	}
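
For the opposite direction, which the new ToCompressedFormatNullable above performs per field type after the merge, a minimal sketch on a plain int64 column (illustrative only, not the project's API) looks like this:

package main

import "fmt"

// toCompressedFormat drops the zero-filled slots of a FULL-format nullable
// column so only non-null values remain, which is roughly what
// ToCompressedFormatNullable does per scalar type before results are handed
// back. Illustrative only; the real code operates on schemapb.FieldData.
func toCompressedFormat(full []int64, valid []bool) []int64 {
	compressed := make([]int64, 0, len(full))
	for i, ok := range valid {
		if ok {
			compressed = append(compressed, full[i])
		}
	}
	return compressed
}

func main() {
	full := []int64{1, 0, 2}           // FULL form of the logical column [1, null, 2]
	valid := []bool{true, false, true} // validity mask is unchanged by the conversion
	fmt.Println(toCompressedFormat(full, valid)) // COMPRESSED form: [1 2]
}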


@@ -340,12 +340,12 @@ func (v *validateUtil) fillWithValue(data []*schemapb.FieldData, schema *typeuti
 		}
 		if fieldSchema.GetDefaultValue() == nil {
-			err = v.fillWithNullValue(field, fieldSchema, numRows)
+			err = FillWithNullValue(field, fieldSchema, numRows)
 			if err != nil {
 				return err
 			}
 		} else {
-			err = v.fillWithDefaultValue(field, fieldSchema, numRows)
+			err = FillWithDefaultValue(field, fieldSchema, numRows)
 			if err != nil {
 				return err
 			}
@@ -355,7 +355,7 @@ func (v *validateUtil) fillWithValue(data []*schemapb.FieldData, schema *typeuti
 	return nil
 }

-func (v *validateUtil) fillWithNullValue(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema, numRows int) error {
+func FillWithNullValue(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema, numRows int) error {
 	err := nullutil.CheckValidData(field.GetValidData(), fieldSchema, numRows)
 	if err != nil {
 		return err
@@ -434,7 +434,7 @@ func (v *validateUtil) fillWithNullValue(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema, numRows int) error {
 	return nil
 }

-func (v *validateUtil) fillWithDefaultValue(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema, numRows int) error {
+func FillWithDefaultValue(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema, numRows int) error {
 	var err error
 	switch field.Field.(type) {
 	case *schemapb.FieldData_Scalars:


@@ -755,26 +755,6 @@ func PrepareResultFieldData(sample []*schemapb.FieldData, topK int64) []*schemapb.FieldData {
 }

 func AppendFieldData(dst, src []*schemapb.FieldData, idx int64) (appendSize int64) {
-	return AppendFieldDataWithNullData(dst, src, idx, false)
-}
-
-// AppendFieldData appends field data of specified index from src to dst
-//
-// Note: The field data in src may have two different nullable formats, so the caller
-// must specify how to handle null values using the skipAppendNullData parameter.
-//
-// Two nullable data formats are supported:
-//
-// Case 1: Before validateUtil.fillWithValue processing
-//	Data: [1, null, 2] = [1, 2] + [true, false, true]
-//	Set skipAppendNullData = true to skip appending values to data array of dst
-//
-// Case 2: After validateUtil.fillWithValue processing
-//	Data: [1, null, 2] = [1, 0, 2] + [true, false, true]
-//	Set skipAppendNullData = false to append zero values to data array of dst
-//
-// TODO: Unify nullable format - SDK uses Case 1, Milvus uses Case 2
-func AppendFieldDataWithNullData(dst, src []*schemapb.FieldData, idx int64, skipAppendNullData bool) (appendSize int64) {
 	dstMap := make(map[int64]*schemapb.FieldData)
 	for _, fieldData := range dst {
 		if fieldData != nil {
@@ -799,10 +779,6 @@ func AppendFieldDataWithNullData(dst, src []*schemapb.FieldData, idx int64, skipAppendNullData bool) (appendSize int64) {
 			}
 			valid := fieldData.ValidData[idx]
 			dstFieldData.ValidData = append(dstFieldData.ValidData, valid)
-			if !valid && skipAppendNullData {
-				continue
-			}
 		}
 		switch fieldType := fieldData.Field.(type) {
 		case *schemapb.FieldData_Scalars:
@@ -1108,25 +1084,10 @@ func UpdateFieldData(base, update []*schemapb.FieldData, baseIdx, updateIdx int64) error {
 			continue
 		}

-		updateFieldIdx := updateIdx
 		// Update ValidData if present
 		if len(updateFieldData.GetValidData()) != 0 {
 			if len(baseFieldData.GetValidData()) != 0 {
-				baseFieldData.ValidData[baseIdx] = updateFieldData.ValidData[updateFieldIdx]
+				baseFieldData.ValidData[baseIdx] = updateFieldData.ValidData[updateIdx]
 			}
-
-			// update field data to null, only modify valid data
-			if !updateFieldData.ValidData[updateFieldIdx] {
-				continue
-			}
-
-			// for nullable field data, such as data=[1,1], valid_data=[true, false, true]
-			// should update the updateFieldIdx to the expected valid index
-			updateFieldIdx = 0
-			for _, validData := range updateFieldData.GetValidData()[:updateIdx] {
-				if validData {
-					updateFieldIdx += 1
-				}
-			}
 		}
@@ -1140,43 +1101,58 @@ func UpdateFieldData(base, update []*schemapb.FieldData, baseIdx, updateIdx int64) error {
 		switch baseScalar.Data.(type) {
 		case *schemapb.ScalarField_BoolData:
 			updateData := updateScalar.GetBoolData()
-			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
-				baseScalar.GetBoolData().Data[baseIdx] = updateData.Data[updateFieldIdx]
+			baseData := baseScalar.GetBoolData()
+			if updateData != nil && baseData != nil &&
+				int(updateIdx) < len(updateData.Data) && int(baseIdx) < len(baseData.Data) {
+				baseData.Data[baseIdx] = updateData.Data[updateIdx]
 			}
 		case *schemapb.ScalarField_IntData:
 			updateData := updateScalar.GetIntData()
-			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
-				baseScalar.GetIntData().Data[baseIdx] = updateData.Data[updateFieldIdx]
+			baseData := baseScalar.GetIntData()
+			if updateData != nil && baseData != nil &&
+				int(updateIdx) < len(updateData.Data) && int(baseIdx) < len(baseData.Data) {
+				baseData.Data[baseIdx] = updateData.Data[updateIdx]
 			}
 		case *schemapb.ScalarField_LongData:
 			updateData := updateScalar.GetLongData()
-			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
-				baseScalar.GetLongData().Data[baseIdx] = updateData.Data[updateFieldIdx]
+			baseData := baseScalar.GetLongData()
+			if updateData != nil && baseData != nil &&
+				int(updateIdx) < len(updateData.Data) && int(baseIdx) < len(baseData.Data) {
+				baseData.Data[baseIdx] = updateData.Data[updateIdx]
 			}
 		case *schemapb.ScalarField_FloatData:
 			updateData := updateScalar.GetFloatData()
-			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
-				baseScalar.GetFloatData().Data[baseIdx] = updateData.Data[updateFieldIdx]
+			baseData := baseScalar.GetFloatData()
+			if updateData != nil && baseData != nil &&
+				int(updateIdx) < len(updateData.Data) && int(baseIdx) < len(baseData.Data) {
+				baseData.Data[baseIdx] = updateData.Data[updateIdx]
 			}
 		case *schemapb.ScalarField_DoubleData:
 			updateData := updateScalar.GetDoubleData()
-			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
-				baseScalar.GetDoubleData().Data[baseIdx] = updateData.Data[updateFieldIdx]
+			baseData := baseScalar.GetDoubleData()
+			if updateData != nil && baseData != nil &&
+				int(updateIdx) < len(updateData.Data) && int(baseIdx) < len(baseData.Data) {
+				baseData.Data[baseIdx] = updateData.Data[updateIdx]
 			}
 		case *schemapb.ScalarField_StringData:
 			updateData := updateScalar.GetStringData()
-			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
-				baseScalar.GetStringData().Data[baseIdx] = updateData.Data[updateFieldIdx]
+			baseData := baseScalar.GetStringData()
+			if updateData != nil && baseData != nil &&
+				int(updateIdx) < len(updateData.Data) && int(baseIdx) < len(baseData.Data) {
+				baseData.Data[baseIdx] = updateData.Data[updateIdx]
 			}
 		case *schemapb.ScalarField_ArrayData:
 			updateData := updateScalar.GetArrayData()
-			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
-				baseScalar.GetArrayData().Data[baseIdx] = updateData.Data[updateFieldIdx]
+			baseData := baseScalar.GetArrayData()
+			if updateData != nil && baseData != nil &&
+				int(updateIdx) < len(updateData.Data) && int(baseIdx) < len(baseData.Data) {
+				baseData.Data[baseIdx] = updateData.Data[updateIdx]
 			}
 		case *schemapb.ScalarField_JsonData:
 			updateData := updateScalar.GetJsonData()
 			baseData := baseScalar.GetJsonData()
-			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
+			if updateData != nil && baseData != nil &&
+				int(updateIdx) < len(updateData.Data) && int(baseIdx) < len(baseData.Data) {
 				if baseFieldData.GetIsDynamic() {
 					// dynamic field is a json with only 1 level nested struct,
 					// so we need to unmarshal and iterate updateData's key value, and update the baseData's key value
@@ -1186,7 +1162,7 @@ func UpdateFieldData(base, update []*schemapb.FieldData, baseIdx, updateIdx int64) error {
 					if err := json.Unmarshal(baseData.Data[baseIdx], &baseMap); err != nil {
 						return fmt.Errorf("failed to unmarshal base json: %v", err)
 					}
-					if err := json.Unmarshal(updateData.Data[updateFieldIdx], &updateMap); err != nil {
+					if err := json.Unmarshal(updateData.Data[updateIdx], &updateMap); err != nil {
 						return fmt.Errorf("failed to unmarshal update json: %v", err)
 					}
 					// merge
@@ -1200,7 +1176,7 @@ func UpdateFieldData(base, update []*schemapb.FieldData, baseIdx, updateIdx int64) error {
 					}
 					baseScalar.GetJsonData().Data[baseIdx] = newJSON
 				} else {
-					baseScalar.GetJsonData().Data[baseIdx] = updateData.Data[updateFieldIdx]
+					baseScalar.GetJsonData().Data[baseIdx] = updateData.Data[updateIdx]
 				}
 			}
 		default:
@@ -1216,73 +1192,71 @@ func UpdateFieldData(base, update []*schemapb.FieldData, baseIdx, updateIdx int64) error {
 		switch baseVector.Data.(type) {
 		case *schemapb.VectorField_BinaryVector:
-			updateData := updateVector.GetBinaryVector()
-			if updateData != nil {
-				baseData := baseVector.GetBinaryVector()
-				baseStartIdx := updateFieldIdx * (dim / 8)
-				baseEndIdx := (updateFieldIdx + 1) * (dim / 8)
-				updateStartIdx := updateFieldIdx * (dim / 8)
-				updateEndIdx := (updateFieldIdx + 1) * (dim / 8)
+			baseData := baseVector.GetBinaryVector()
+			updateData := updateVector.GetBinaryVector()
+			if baseData != nil && updateData != nil {
+				baseStartIdx := baseIdx * (dim / 8)
+				baseEndIdx := (baseIdx + 1) * (dim / 8)
+				updateStartIdx := updateIdx * (dim / 8)
+				updateEndIdx := (updateIdx + 1) * (dim / 8)
 				if int(updateEndIdx) <= len(updateData) && int(baseEndIdx) <= len(baseData) {
 					copy(baseData[baseStartIdx:baseEndIdx], updateData[updateStartIdx:updateEndIdx])
 				}
 			}
 		case *schemapb.VectorField_FloatVector:
-			updateData := updateVector.GetFloatVector()
-			if updateData != nil {
-				baseData := baseVector.GetFloatVector()
-				baseStartIdx := updateFieldIdx * dim
-				baseEndIdx := (updateFieldIdx + 1) * dim
-				updateStartIdx := updateFieldIdx * dim
-				updateEndIdx := (updateFieldIdx + 1) * dim
+			baseData := baseVector.GetFloatVector()
+			updateData := updateVector.GetFloatVector()
+			if baseData != nil && updateData != nil {
+				baseStartIdx := baseIdx * dim
+				baseEndIdx := (baseIdx + 1) * dim
+				updateStartIdx := updateIdx * dim
+				updateEndIdx := (updateIdx + 1) * dim
 				if int(updateEndIdx) <= len(updateData.Data) && int(baseEndIdx) <= len(baseData.Data) {
 					copy(baseData.Data[baseStartIdx:baseEndIdx], updateData.Data[updateStartIdx:updateEndIdx])
 				}
 			}
 		case *schemapb.VectorField_Float16Vector:
-			updateData := updateVector.GetFloat16Vector()
-			if updateData != nil {
-				baseData := baseVector.GetFloat16Vector()
-				baseStartIdx := updateFieldIdx * (dim * 2)
-				baseEndIdx := (updateFieldIdx + 1) * (dim * 2)
-				updateStartIdx := updateFieldIdx * (dim * 2)
-				updateEndIdx := (updateFieldIdx + 1) * (dim * 2)
+			baseData := baseVector.GetFloat16Vector()
+			updateData := updateVector.GetFloat16Vector()
+			if baseData != nil && updateData != nil {
+				baseStartIdx := baseIdx * (dim * 2)
+				baseEndIdx := (baseIdx + 1) * (dim * 2)
+				updateStartIdx := updateIdx * (dim * 2)
+				updateEndIdx := (updateIdx + 1) * (dim * 2)
 				if int(updateEndIdx) <= len(updateData) && int(baseEndIdx) <= len(baseData) {
 					copy(baseData[baseStartIdx:baseEndIdx], updateData[updateStartIdx:updateEndIdx])
 				}
 			}
 		case *schemapb.VectorField_Bfloat16Vector:
-			updateData := updateVector.GetBfloat16Vector()
-			if updateData != nil {
-				baseData := baseVector.GetBfloat16Vector()
-				baseStartIdx := updateFieldIdx * (dim * 2)
-				baseEndIdx := (updateFieldIdx + 1) * (dim * 2)
-				updateStartIdx := updateFieldIdx * (dim * 2)
-				updateEndIdx := (updateFieldIdx + 1) * (dim * 2)
+			baseData := baseVector.GetBfloat16Vector()
+			updateData := updateVector.GetBfloat16Vector()
+			if baseData != nil && updateData != nil {
+				baseStartIdx := baseIdx * (dim * 2)
+				baseEndIdx := (baseIdx + 1) * (dim * 2)
+				updateStartIdx := updateIdx * (dim * 2)
+				updateEndIdx := (updateIdx + 1) * (dim * 2)
 				if int(updateEndIdx) <= len(updateData) && int(baseEndIdx) <= len(baseData) {
 					copy(baseData[baseStartIdx:baseEndIdx], updateData[updateStartIdx:updateEndIdx])
 				}
 			}
 		case *schemapb.VectorField_SparseFloatVector:
-			updateData := updateVector.GetSparseFloatVector()
-			if updateData != nil && int(updateFieldIdx) < len(updateData.Contents) {
-				baseData := baseVector.GetSparseFloatVector()
-				if int(updateFieldIdx) < len(baseData.Contents) {
-					baseData.Contents[updateFieldIdx] = updateData.Contents[updateFieldIdx]
-					// Update dimension if necessary
-					if updateData.Dim > baseData.Dim {
-						baseData.Dim = updateData.Dim
-					}
-				}
-			}
+			baseData := baseVector.GetSparseFloatVector()
+			updateData := updateVector.GetSparseFloatVector()
+			if baseData != nil && updateData != nil && int(baseIdx) < len(baseData.Contents) && int(updateIdx) < len(updateData.Contents) {
+				baseData.Contents[baseIdx] = updateData.Contents[updateIdx]
+				// Update dimension if necessary
+				if updateData.Dim > baseData.Dim {
+					baseData.Dim = updateData.Dim
+				}
+			}
 		case *schemapb.VectorField_Int8Vector:
-			updateData := updateVector.GetInt8Vector()
-			if updateData != nil {
-				baseData := baseVector.GetInt8Vector()
-				baseStartIdx := updateFieldIdx * dim
-				baseEndIdx := (updateFieldIdx + 1) * dim
-				updateStartIdx := updateFieldIdx * dim
-				updateEndIdx := (updateFieldIdx + 1) * dim
+			baseData := baseVector.GetInt8Vector()
+			updateData := updateVector.GetInt8Vector()
+			if baseData != nil && updateData != nil {
+				baseStartIdx := baseIdx * dim
+				baseEndIdx := (baseIdx + 1) * dim
+				updateStartIdx := updateIdx * dim
+				updateEndIdx := (updateIdx + 1) * dim
 				if int(updateEndIdx) <= len(updateData) && int(baseEndIdx) <= len(baseData) {
 					copy(baseData[baseStartIdx:baseEndIdx], updateData[updateStartIdx:updateEndIdx])
 				}
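
To make the vector branch above concrete: the fix derives offsets into the base vector from baseIdx and offsets into the update vector from updateIdx, and bounds-checks both slices before copying, which is what prevents the out-of-bounds writes mentioned in the commit message. A minimal standalone sketch of that pattern, using a hypothetical helper rather than the project's schemapb types:

package main

import "fmt"

// updateFloatVectorRow copies row updateIdx of update into row baseIdx of
// base, where both slices store dim-wide rows back to back. It reports false
// instead of panicking when either row would fall outside its slice.
// Hypothetical helper for illustration only.
func updateFloatVectorRow(base, update []float32, dim, baseIdx, updateIdx int) bool {
	baseStart, baseEnd := baseIdx*dim, (baseIdx+1)*dim
	updStart, updEnd := updateIdx*dim, (updateIdx+1)*dim
	if baseEnd > len(base) || updEnd > len(update) {
		return false // out-of-range update is skipped rather than crashing
	}
	copy(base[baseStart:baseEnd], update[updStart:updEnd])
	return true
}

func main() {
	base := []float32{0, 0, 0, 0, 0, 0} // three rows of dim 2
	update := []float32{7, 8}           // one row of dim 2
	ok := updateFloatVectorRow(base, update, 2, 1, 0)
	fmt.Println(ok, base) // true [0 0 7 8 0 0]
}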

File diff suppressed because it is too large.


@@ -3931,3 +3931,103 @@ def parse_fmod(x: int, y: int) -> int:
    v = abs(x) % abs(y)
    return v if x >= 0 else -v


def gen_partial_row_data_by_schema(nb=ct.default_nb, schema=None, desired_field_names=None, num_fields=1,
                                   start=0, random_pk=False, skip_field_names=[]):
    """
    Generate row data that contains a subset of fields from the given schema.
    Args:
        schema: Collection schema or collection info dict. If None, uses default schema.
        desired_field_names (list[str] | None): Explicit field names to include (intersected with eligible fields).
        num_fields (int): Number of fields to include if desired_field_names is not provided. Defaults to 1.
        start (int): Starting value for primary key fields when sequential values are needed.
        random_pk (bool): Whether to generate random primary key values.
        skip_field_names (list[str]): Field names to skip.
        nb (int): Number of rows to generate. Defaults to 1.
    Returns:
        list[dict]: a list of rows.
    Notes:
        - Skips auto_id fields and function output fields.
        - Primary INT64/VARCHAR fields get sequential values from `start` unless `random_pk=True`.
        - Works with both schema dicts (from v2 client describe_collection) and ORM schema objects.
    """
    if schema is None:
        schema = gen_default_collection_schema()
    func_output_fields = []
    # Build list of eligible fields
    if isinstance(schema, dict):
        fields = schema.get('fields', [])
        functions = schema.get('functions', [])
        for func in functions:
            output_field_names = func.get('output_field_names', [])
            func_output_fields.extend(output_field_names)
        func_output_fields = list(set(func_output_fields))
        eligible_fields = []
        for field in fields:
            field_name = field.get('name', None)
            if field.get('auto_id', False):
                continue
            if field_name in func_output_fields or field_name in skip_field_names:
                continue
            eligible_fields.append(field)
        # Choose subset
        if desired_field_names:
            desired_set = set(desired_field_names)
            chosen_fields = [f for f in eligible_fields if f.get('name') in desired_set]
        else:
            n = max(0, min(len(eligible_fields), num_fields if num_fields is not None else 1))
            chosen_fields = eligible_fields[:n]
        rows = []
        curr_start = start
        for _ in range(nb):
            row = {}
            for field in chosen_fields:
                fname = field.get('name', None)
                value = gen_data_by_collection_field(field, random_pk=random_pk)
                # Override for PKs when not random
                if not random_pk and field.get('is_primary', False) is True:
                    if field.get('type', None) == DataType.INT64:
                        value = curr_start
                        curr_start += 1
                    elif field.get('type', None) == DataType.VARCHAR:
                        value = str(curr_start)
                        curr_start += 1
                row[fname] = value
            rows.append(row)
        return rows
    # ORM schema path
    fields = schema.fields
    if hasattr(schema, "functions"):
        functions = schema.functions
        for func in functions:
            func_output_fields.extend(func.output_field_names)
    func_output_fields = list(set(func_output_fields))
    eligible_fields = []
    for field in fields:
        if field.auto_id:
            continue
        if field.name in func_output_fields or field.name in skip_field_names:
            continue
        eligible_fields.append(field)
    if desired_field_names:
        desired_set = set(desired_field_names)
        chosen_fields = [f for f in eligible_fields if f.name in desired_set]
    else:
        n = max(0, min(len(eligible_fields), num_fields if num_fields is not None else 1))
        chosen_fields = eligible_fields[:n]
    rows = []
    curr_start = start
    for _ in range(nb):
        row = {}
        for field in chosen_fields:
            value = gen_data_by_collection_field(field, random_pk=random_pk)
            if not random_pk and field.is_primary is True:
                if field.dtype == DataType.INT64:
                    value = curr_start
                    curr_start += 1
                elif field.dtype == DataType.VARCHAR:
                    value = str(curr_start)
                    curr_start += 1
            row[field.name] = value
        rows.append(row)
    return rows


@@ -0,0 +1,314 @@
import pytest
import time
import random
import numpy as np
from common.common_type import CaseLabel, CheckTasks
from common import common_func as cf
from common import common_type as ct
from utils.util_log import test_log as log
from utils.util_pymilvus import *
from base.client_v2_base import TestMilvusClientV2Base
from pymilvus import DataType, FieldSchema, CollectionSchema
from sklearn import preprocessing

# Test parameters
default_nb = ct.default_nb
default_nq = ct.default_nq
default_limit = ct.default_limit
default_search_exp = "id >= 0"
exp_res = "exp_res"
default_primary_key_field_name = "id"
default_vector_field_name = "vector"
default_int32_field_name = ct.default_int32_field_name


class TestMilvusClientPartialUpdate(TestMilvusClientV2Base):
    """ Test case of partial update functionality """

    @pytest.mark.tags(CaseLabel.L0)
    def test_partial_update_all_field_types(self):
        """
        Test partial update functionality with all field types
        1. Create collection with all data types
        2. Insert initial data
        3. Perform partial update for each field type
        4. Verify all updates work correctly
        """
        client = self._client()
        dim = 64
        collection_name = cf.gen_collection_name_by_testcase_name()

        # Create schema with all data types
        schema = cf.gen_all_datatype_collection_schema(dim=dim)

        # Create index parameters
        index_params = client.prepare_index_params()
        for i in range(len(schema.fields)):
            field_name = schema.fields[i].name
            print(f"field_name: {field_name}")
            if field_name == "json_field":
                index_params.add_index(field_name, index_type="AUTOINDEX",
                                       params={"json_cast_type": "json"})
            elif field_name == "text_sparse_emb":
                index_params.add_index(field_name, index_type="AUTOINDEX", metric_type="BM25")
            else:
                index_params.add_index(field_name, index_type="AUTOINDEX")

        # Create collection
        client.create_collection(collection_name, default_dim, consistency_level="Strong", schema=schema, index_params=index_params)

        # Load collection
        self.load_collection(client, collection_name)

        # Insert initial data
        nb = 1000
        rows = cf.gen_row_data_by_schema(nb=nb, schema=schema)
        self.upsert(client, collection_name, rows, partial_update=True)
        log.info(f"Inserted {nb} initial records")

        primary_key_field_name = schema.fields[0].name
        for i in range(len(schema.fields)):
            update_field_name = schema.fields[i if i != 0 else 1].name
            new_row = cf.gen_partial_row_data_by_schema(nb=nb, schema=schema,
                                                        desired_field_names=[primary_key_field_name, update_field_name])
            client.upsert(collection_name, new_row, partial_update=True)

        log.info("Partial update test for all field types passed successfully")

    @pytest.mark.tags(CaseLabel.L0)
    def test_partial_update_simple_demo(self):
        """
        Test simple partial update demo with nullable fields
        1. Create collection with explicit schema including nullable fields
        2. Insert initial data with some null values
        3. Perform partial updates with different field combinations
        4. Verify partial update behavior preserves unchanged fields
        """
        client = self._client()
        dim = 3
        collection_name = cf.gen_collection_name_by_testcase_name()

        # Create schema with nullable fields
        schema = self.create_schema(client, enable_dynamic_field=False)[0]
        schema.add_field("id", DataType.INT64, is_primary=True, auto_id=False)
        schema.add_field("vector", DataType.FLOAT_VECTOR, dim=dim)
        schema.add_field("name", DataType.VARCHAR, max_length=100, nullable=True)
        schema.add_field("price", DataType.FLOAT, nullable=True)
        schema.add_field("category", DataType.VARCHAR, max_length=50, nullable=True)

        # Create collection
        self.create_collection(client, collection_name, schema=schema)

        # Create index
        index_params = self.prepare_index_params(client)[0]
        index_params.add_index("vector", index_type="AUTOINDEX", metric_type="L2")
        self.create_index(client, collection_name, index_params=index_params)

        # Load collection
        self.load_collection(client, collection_name)

        # Insert initial data with some null values
        initial_data = [
            {
                "id": 1,
                "vector": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist(),
                "name": "Product A",
                "price": 100.0,
                "category": "Electronics"
            },
            {
                "id": 2,
                "vector": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist(),
                "name": "Product B",
                "price": None,  # Null price
                "category": "Home"
            },
            {
                "id": 3,
                "vector": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist(),
                "name": "Product C",
                "price": None,  # Null price
                "category": "Books"
            }
        ]
        self.upsert(client, collection_name, initial_data, partial_update=False)
        log.info("Inserted initial data with null values")

        # Verify initial state
        results = self.query(client, collection_name, filter="id > 0", output_fields=["*"])[0]
        assert len(results) == 3
        initial_data_map = {data['id']: data for data in results}
        assert initial_data_map[1]['name'] == "Product A"
        assert initial_data_map[1]['price'] == 100.0
        assert initial_data_map[1]['category'] == "Electronics"
        assert initial_data_map[2]['name'] == "Product B"
        assert initial_data_map[2]['price'] is None
        assert initial_data_map[2]['category'] == "Home"
        assert initial_data_map[3]['name'] == "Product C"
        assert initial_data_map[3]['price'] is None
        assert initial_data_map[3]['category'] == "Books"
        log.info("Initial data verification passed")

        # First partial update - update all fields
        log.info("First partial update - updating all fields...")
        first_update_data = [
            {
                "id": 1,
                "vector": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist(),
                "name": "Product A-Update",
                "price": 111.1,
                "category": "Electronics-Update"
            },
            {
                "id": 2,
                "vector": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist(),
                "name": "Product B-Update",
                "price": 222.2,
                "category": "Home-Update"
            },
            {
                "id": 3,
                "vector": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist(),
                "name": "Product C-Update",
                "price": None,  # Still null
                "category": "Books-Update"
            }
        ]
        self.upsert(client, collection_name, first_update_data, partial_update=True)

        # Verify first update
        results = self.query(client, collection_name, filter="id > 0", output_fields=["*"])[0]
        assert len(results) == 3
        first_update_map = {data['id']: data for data in results}
        assert first_update_map[1]['name'] == "Product A-Update"
        assert abs(first_update_map[1]['price'] - 111.1) < 0.001
        assert first_update_map[1]['category'] == "Electronics-Update"
        assert first_update_map[2]['name'] == "Product B-Update"
        assert abs(first_update_map[2]['price'] - 222.2) < 0.001
        assert first_update_map[2]['category'] == "Home-Update"
        assert first_update_map[3]['name'] == "Product C-Update"
        assert first_update_map[3]['price'] is None
        assert first_update_map[3]['category'] == "Books-Update"
        log.info("First partial update verification passed")

        # Second partial update - update only specific fields
        log.info("Second partial update - updating specific fields...")
        second_update_data = [
            {
                "id": 1,
                "name": "Product A-Update-Again",
                "price": 1111.1,
                "category": "Electronics-Update-Again"
            },
            {
                "id": 2,
                "name": "Product B-Update-Again",
                "price": None,  # Set back to null
                "category": "Home-Update-Again"
            },
            {
                "id": 3,
                "name": "Product C-Update-Again",
                "price": 3333.3,  # Set price from null to value
                "category": "Books-Update-Again"
            }
        ]
        self.upsert(client, collection_name, second_update_data, partial_update=True)

        # Verify second update
        results = self.query(client, collection_name, filter="id > 0", output_fields=["*"])[0]
        assert len(results) == 3
        second_update_map = {data['id']: data for data in results}
        # Verify ID 1: all fields updated
        assert second_update_map[1]['name'] == "Product A-Update-Again"
        assert abs(second_update_map[1]['price'] - 1111.1) < 0.001
        assert second_update_map[1]['category'] == "Electronics-Update-Again"
        # Verify ID 2: all fields updated, price set to null
        assert second_update_map[2]['name'] == "Product B-Update-Again"
        assert second_update_map[2]['price'] is None
        assert second_update_map[2]['category'] == "Home-Update-Again"
        # Verify ID 3: all fields updated, price set from null to value
        assert second_update_map[3]['name'] == "Product C-Update-Again"
        assert abs(second_update_map[3]['price'] - 3333.3) < 0.001
        assert second_update_map[3]['category'] == "Books-Update-Again"
        # Verify vector fields were preserved from first update (not updated in second update)
        # Note: Vector comparison might be complex, so we just verify they exist
        assert 'vector' in second_update_map[1]
        assert 'vector' in second_update_map[2]
        assert 'vector' in second_update_map[3]
        log.info("Second partial update verification passed")

        log.info("Simple partial update demo test completed successfully")

    @pytest.mark.tags(CaseLabel.L0)
    def test_milvus_client_partial_update_null_to_null(self):
        """
        Target: test PU can successfully update a null to null
        Method:
            1. Create a collection, enable nullable fields
            2. Insert default_nb rows to the collection
            3. Partial Update the nullable field with null
            4. Query the collection to check the value of nullable field
        Expected: query should have correct value and number of entities
        """
        # step 1: create collection with nullable fields
        client = self._client()
        schema = self.create_schema(client, enable_dynamic_field=False)[0]
        schema.add_field(default_primary_key_field_name, DataType.INT64, is_primary=True, auto_id=False)
        schema.add_field(default_vector_field_name, DataType.FLOAT_VECTOR, dim=default_dim)
        schema.add_field(default_int32_field_name, DataType.INT32, nullable=True)
        index_params = self.prepare_index_params(client)[0]
        index_params.add_index(default_primary_key_field_name, index_type="AUTOINDEX")
        index_params.add_index(default_vector_field_name, index_type="AUTOINDEX")
        index_params.add_index(default_int32_field_name, index_type="AUTOINDEX")
        collection_name = cf.gen_collection_name_by_testcase_name(module_index=1)
        self.create_collection(client, collection_name, default_dim, schema=schema,
                               consistency_level="Strong", index_params=index_params)

        # step 2: insert default_nb rows to the collection
        rows = cf.gen_row_data_by_schema(nb=default_nb, schema=schema, skip_field_names=[default_int32_field_name])
        self.upsert(client, collection_name, rows, partial_update=True)

        # step 3: Partial Update the nullable field with null
        new_row = cf.gen_partial_row_data_by_schema(
            nb=default_nb,
            schema=schema,
            desired_field_names=[default_primary_key_field_name, default_int32_field_name],
            start=0
        )
        # Set the nullable field to None
        for data in new_row:
            data[default_int32_field_name] = None
        self.upsert(client, collection_name, new_row, partial_update=True)

        # step 4: Query the collection to check the value of nullable field
        result = self.query(client, collection_name, filter=default_search_exp,
                            check_task=CheckTasks.check_query_results,
                            output_fields=[default_int32_field_name],
                            check_items={exp_res: new_row,
                                         "with_vec": True,
                                         "pk_name": default_primary_key_field_name})[0]
        assert len(result) == default_nb
        # Verify that all nullable fields are indeed null
        for data in result:
            assert data[default_int32_field_name] is None, f"Expected null value for {default_int32_field_name}, got {data[default_int32_field_name]}"
        log.info("Partial update null to null test completed successfully")
        self.drop_collection(client, collection_name)