enhance: Optimize partial update merge logic by unifying nullable format (#44197)

issue: #43980

This commit optimizes the partial update merge logic by standardizing the nullable field representation before merge operations, avoiding corner cases during the merge process.

Key changes:
- Unify nullable field data format to FULL FORMAT before merge execution
- Add extensive unit tests for bounds checking and edge cases

The optimization ensures:
- Consistent nullable field representation across SDK and internal code
- Proper handling of null values during merge operations
- Prevention of index out-of-bounds errors in vector field updates
- Better error handling and validation for partial update scenarios

This resolves issues where different nullable field formats could cause merge failures or data corruption during partial update operations.

Signed-off-by: Wei Liu <wei.liu@zilliz.com>
This commit is contained in:
parent 68fb357515, commit 18371773dd
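For orientation before reading the diff, here is a minimal sketch of the two nullable representations the commit message refers to, built with the schemapb types that appear in the hunks below. The field name "price" is illustrative, and LongArray as the int64 wrapper type is an assumption inferred from the GetLongData() calls in the diff.

// Sketch: the same logical int64 column [1, null, 2] in both formats.

// COMPRESSED FORMAT (as an SDK may send it): nulls are omitted from Data.
compressed := &schemapb.FieldData{
    FieldName: "price", // illustrative field
    Type:      schemapb.DataType_Int64,
    ValidData: []bool{true, false, true}, // one flag per row
    Field: &schemapb.FieldData_Scalars{
        Scalars: &schemapb.ScalarField{
            Data: &schemapb.ScalarField_LongData{
                LongData: &schemapb.LongArray{Data: []int64{1, 2}}, // null row skipped
            },
        },
    },
}

// FULL FORMAT (Milvus internal): nulls occupy a zero-value slot in Data.
full := &schemapb.FieldData{
    FieldName: "price",
    Type:      schemapb.DataType_Int64,
    ValidData: []bool{true, false, true},
    Field: &schemapb.FieldData_Scalars{
        Scalars: &schemapb.ScalarField{
            Data: &schemapb.ScalarField_LongData{
                LongData: &schemapb.LongArray{Data: []int64{1, 0, 2}}, // zero placeholder at row 1
            },
        },
    },
}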
@@ -296,17 +296,39 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
 		if fieldData.GetIsDynamic() {
 			fieldName = "$meta"
 		}
-		fieldID, ok := it.schema.MapFieldID(fieldName)
-		if !ok {
-			log.Info("field not found in schema", zap.Any("field", fieldData))
-			return merr.WrapErrParameterInvalidMsg("field not found in schema")
+		fieldSchema, err := it.schema.schemaHelper.GetFieldFromName(fieldName)
+		if err != nil {
+			log.Info("get field schema failed", zap.Error(err))
+			return err
 		}
-		fieldData.FieldId = fieldID
+		fieldData.FieldId = fieldSchema.GetFieldID()
 		fieldData.FieldName = fieldName
+
+		// compatible with different nullable data format from sdk
+		if len(fieldData.GetValidData()) != 0 {
+			err := FillWithNullValue(fieldData, fieldSchema, int(it.upsertMsg.InsertMsg.NRows()))
+			if err != nil {
+				log.Info("unify null field data format failed", zap.Error(err))
+				return err
+			}
+		}
 	}
 
-	// Note: the most difficult part is to handle the merge progress of upsert and query result
-	// we need to enable merge logic on different length between upsertFieldData and it.insertFieldData
+	// Two nullable data formats are supported:
+	//
+	// COMPRESSED FORMAT (SDK format, before validateUtil.fillWithValue processing):
+	// Logical data: [1, null, 2]
+	// Storage: Data=[1, 2] + ValidData=[true, false, true]
+	// - Data array contains only non-null values (compressed)
+	// - ValidData array tracks null positions for all rows
+	//
+	// FULL FORMAT (Milvus internal format, after validateUtil.fillWithValue processing):
+	// Logical data: [1, null, 2]
+	// Storage: Data=[1, 0, 2] + ValidData=[true, false, true]
+	// - Data array contains values for all rows (nulls filled with zero/default)
+	// - ValidData array still tracks null positions
+	//
+	// Note: we will unify the nullable format to FULL FORMAT before executing the merge logic
 	insertIdxInUpsert := make([]int, 0)
 	updateIdxInUpsert := make([]int, 0)
 	// 1. split upsert data into insert and update by query result
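The unification step above calls FillWithNullValue, which conceptually re-expands a COMPRESSED column against its ValidData mask. A minimal sketch for an int64 column (not the actual implementation, which also validates lengths against the schema):

// expandInt64 sketches the COMPRESSED -> FULL FORMAT expansion:
// re-expand the value array against the per-row ValidData mask.
func expandInt64(data []int64, valid []bool) []int64 {
    full := make([]int64, 0, len(valid))
    next := 0 // cursor into the compressed value array
    for _, ok := range valid {
        if ok {
            full = append(full, data[next])
            next++
        } else {
            full = append(full, 0) // zero placeholder for the null row
        }
    }
    return full
}

// expandInt64([]int64{1, 2}, []bool{true, false, true}) -> [1 0 2]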
@@ -367,7 +389,7 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
 			if !ok {
 				return merr.WrapErrParameterInvalidMsg("primary key not found in exist data mapping")
 			}
-			typeutil.AppendFieldDataWithNullData(it.insertFieldData, existFieldData, int64(existIndex), false)
+			typeutil.AppendFieldData(it.insertFieldData, existFieldData, int64(existIndex))
 			err := typeutil.UpdateFieldData(it.insertFieldData, upsertFieldData, int64(baseIdx), int64(idx))
 			baseIdx += 1
 			if err != nil {
@@ -386,7 +408,7 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
 			return lackOfFieldErr
 		}
 
-		// if the nullable or default value field is not set in upsert request, which means the len(upsertFieldData) < len(it.insertFieldData)
+		// if the nullable field has not passed in upsert request, which means the len(upsertFieldData) < len(it.insertFieldData)
 		// we need to generate the nullable field data before append as insert
 		insertWithNullField := make([]*schemapb.FieldData, 0)
 		upsertFieldMap := lo.SliceToMap(it.upsertMsg.InsertMsg.GetFieldsData(), func(field *schemapb.FieldData) (string, *schemapb.FieldData) {
@@ -407,27 +429,27 @@ func (it *upsertTask) queryPreExecute(ctx context.Context) error {
 			}
 		}
 		for _, idx := range insertIdxInUpsert {
-			typeutil.AppendFieldDataWithNullData(it.insertFieldData, insertWithNullField, int64(idx), true)
+			typeutil.AppendFieldData(it.insertFieldData, insertWithNullField, int64(idx))
 		}
 	}
 
-	// 4. clean field data with valid data after merge upsert and query result
 	for _, fieldData := range it.insertFieldData {
-		// Note: Since protobuf cannot correctly identify null values, zero values + valid data are used to identify null values,
-		// therefore for field data obtained from query results, if the field is nullable, it needs to clean zero values
-		if len(fieldData.GetValidData()) != 0 && getValidNumber(fieldData.GetValidData()) != len(fieldData.GetValidData()) {
-			err := ResetNullFieldData(fieldData)
+		if len(fieldData.GetValidData()) > 0 {
+			err := ToCompressedFormatNullable(fieldData)
 			if err != nil {
-				log.Info("reset null field data failed", zap.Error(err))
+				log.Info("convert to compressed format nullable failed", zap.Error(err))
 				return err
 			}
 		}
 	}
 
 	return nil
 }
 
-func ResetNullFieldData(field *schemapb.FieldData) error {
+// ToCompressedFormatNullable converts the field data from full format nullable to compressed format nullable
+func ToCompressedFormatNullable(field *schemapb.FieldData) error {
+	if getValidNumber(field.GetValidData()) == len(field.GetValidData()) {
+		return nil
+	}
 	switch field.Field.(type) {
 	case *schemapb.FieldData_Scalars:
 		switch sd := field.GetScalars().GetData().(type) {
@@ -529,6 +551,20 @@ func ResetNullFieldData(field *schemapb.FieldData) error {
 			sd.JsonData.Data = ret
 		}
 
+	case *schemapb.ScalarField_ArrayData:
+		validRowNum := getValidNumber(field.GetValidData())
+		if validRowNum == 0 {
+			sd.ArrayData.Data = make([]*schemapb.ScalarField, 0)
+		} else {
+			ret := make([]*schemapb.ScalarField, 0, validRowNum)
+			for i, valid := range field.GetValidData() {
+				if valid {
+					ret = append(ret, sd.ArrayData.Data[i])
+				}
+			}
+			sd.ArrayData.Data = ret
+		}
+
 	default:
 		return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined data type:%s", field.Type.String()))
 	}
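ToCompressedFormatNullable goes the opposite way from the expansion sketched earlier: per concrete type, it filters the data array by the ValidData mask. A generic sketch of that filtering (the real function also short-circuits when nothing is null, as shown above):

// compressByMask sketches the FULL FORMAT -> COMPRESSED filtering:
// keep only the entries whose validity flag is true.
func compressByMask[T any](data []T, valid []bool) []T {
    ret := make([]T, 0, len(data))
    for i, ok := range valid {
        if ok {
            ret = append(ret, data[i])
        }
    }
    return ret
}

// compressByMask([]int64{1, 0, 2}, []bool{true, false, true}) -> [1 2]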
@@ -540,6 +576,7 @@ func ResetNullFieldData(field *schemapb.FieldData) error {
 	return nil
 }
 
+// GenNullableFieldData generates nullable field data in FULL FORMAT
 func GenNullableFieldData(field *schemapb.FieldSchema, upsertIDSize int) (*schemapb.FieldData, error) {
 	switch field.DataType {
 	case schemapb.DataType_Bool:
@@ -668,6 +705,24 @@ func GenNullableFieldData(field *schemapb.FieldSchema, upsertIDSize int) (*schem
 			},
 		}, nil
 
+	case schemapb.DataType_Array:
+		return &schemapb.FieldData{
+			FieldId:   field.FieldID,
+			FieldName: field.Name,
+			Type:      field.DataType,
+			IsDynamic: field.IsDynamic,
+			ValidData: make([]bool, upsertIDSize),
+			Field: &schemapb.FieldData_Scalars{
+				Scalars: &schemapb.ScalarField{
+					Data: &schemapb.ScalarField_ArrayData{
+						ArrayData: &schemapb.ArrayArray{
+							Data: make([]*schemapb.ScalarField, upsertIDSize),
+						},
+					},
+				},
+			},
+		}, nil
+
 	default:
 		return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined scalar data type:%s", field.DataType.String()))
 	}

@@ -340,12 +340,12 @@ func (v *validateUtil) fillWithValue(data []*schemapb.FieldData, schema *typeuti
 		}
 
 		if fieldSchema.GetDefaultValue() == nil {
-			err = v.fillWithNullValue(field, fieldSchema, numRows)
+			err = FillWithNullValue(field, fieldSchema, numRows)
 			if err != nil {
 				return err
 			}
 		} else {
-			err = v.fillWithDefaultValue(field, fieldSchema, numRows)
+			err = FillWithDefaultValue(field, fieldSchema, numRows)
 			if err != nil {
 				return err
 			}
@@ -355,7 +355,7 @@ func (v *validateUtil) fillWithValue(data []*schemapb.FieldData, schema *typeuti
 	return nil
 }
 
-func (v *validateUtil) fillWithNullValue(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema, numRows int) error {
+func FillWithNullValue(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema, numRows int) error {
 	err := nullutil.CheckValidData(field.GetValidData(), fieldSchema, numRows)
 	if err != nil {
 		return err
@@ -434,7 +434,7 @@ func (v *validateUtil) fillWithNullValue(field *schemapb.FieldData, fieldSchema
 	return nil
 }
 
-func (v *validateUtil) fillWithDefaultValue(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema, numRows int) error {
+func FillWithDefaultValue(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema, numRows int) error {
 	var err error
 	switch field.Field.(type) {
 	case *schemapb.FieldData_Scalars:

@@ -755,26 +755,6 @@ func PrepareResultFieldData(sample []*schemapb.FieldData, topK int64) []*schemap
 }
 
 func AppendFieldData(dst, src []*schemapb.FieldData, idx int64) (appendSize int64) {
-	return AppendFieldDataWithNullData(dst, src, idx, false)
-}
-
-// AppendFieldData appends field data of specified index from src to dst
-//
-// Note: The field data in src may have two different nullable formats, so the caller
-// must specify how to handle null values using the skipAppendNullData parameter.
-//
-// Two nullable data formats are supported:
-//
-// Case 1: Before validateUtil.fillWithValue processing
-//	Data: [1, null, 2] = [1, 2] + [true, false, true]
-//	Set skipAppendNullData = true to skip appending values to data array of dst
-//
-// Case 2: After validateUtil.fillWithValue processing
-//	Data: [1, null, 2] = [1, 0, 2] + [true, false, true]
-//	Set skipAppendNullData = false to append zero values to data array of dst
-//
-// TODO: Unify nullable format - SDK uses Case 1, Milvus uses Case 2
-func AppendFieldDataWithNullData(dst, src []*schemapb.FieldData, idx int64, skipAppendNullData bool) (appendSize int64) {
 	dstMap := make(map[int64]*schemapb.FieldData)
 	for _, fieldData := range dst {
 		if fieldData != nil {
@@ -799,10 +779,6 @@ func AppendFieldDataWithNullData(dst, src []*schemapb.FieldData, idx int64, skip
 			}
 			valid := fieldData.ValidData[idx]
 			dstFieldData.ValidData = append(dstFieldData.ValidData, valid)
-
-			if !valid && skipAppendNullData {
-				continue
-			}
 		}
 		switch fieldType := fieldData.Field.(type) {
 		case *schemapb.FieldData_Scalars:
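With every input normalized to FULL FORMAT, the deleted skip branch was dead weight: appending row idx now always appends both the (possibly zero-placeholder) value and its validity flag, so Data and ValidData stay the same length. A sketch of the invariant, with plain slices standing in for one destination field's arrays (names hypothetical):

// appendRow sketches the invariant: Data and ValidData grow together,
// so the two arrays never diverge in length.
func appendRow(dstData []int64, dstValid []bool, srcData []int64, srcValid []bool, idx int) ([]int64, []bool) {
    dstValid = append(dstValid, srcValid[idx])
    dstData = append(dstData, srcData[idx]) // for a null row this appends the zero placeholder
    return dstData, dstValid
}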
@@ -1108,25 +1084,10 @@ func UpdateFieldData(base, update []*schemapb.FieldData, baseIdx, updateIdx int6
 			continue
 		}
 
-		updateFieldIdx := updateIdx
 		// Update ValidData if present
 		if len(updateFieldData.GetValidData()) != 0 {
 			if len(baseFieldData.GetValidData()) != 0 {
-				baseFieldData.ValidData[baseIdx] = updateFieldData.ValidData[updateFieldIdx]
+				baseFieldData.ValidData[baseIdx] = updateFieldData.ValidData[updateIdx]
 			}
-
-			// update field data to null, only modify valid data
-			if !updateFieldData.ValidData[updateFieldIdx] {
-				continue
-			}
-
-			// for nullable field data, such as data=[1,1], valid_data=[true, false, true]
-			// should update the updateFieldIdx to the expected valid index
-			updateFieldIdx = 0
-			for _, validData := range updateFieldData.GetValidData()[:updateIdx] {
-				if validData {
-					updateFieldIdx += 1
-				}
-			}
 		}
 
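Why the counting loop could be deleted: it existed only because COMPRESSED FORMAT stores fewer data entries than rows, so a row index had to be translated into "number of valid rows before it". A worked example with hypothetical values: given valid=[true, false, true], COMPRESSED data=[7, 9] keeps row 2's value at data index 1 (hence the loop), while FULL FORMAT data=[7, 0, 9] keeps it at data index 2, so updateIdx can be used directly in every case below.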
@@ -1140,43 +1101,58 @@ func UpdateFieldData(base, update []*schemapb.FieldData, baseIdx, updateIdx int6
 		switch baseScalar.Data.(type) {
 		case *schemapb.ScalarField_BoolData:
 			updateData := updateScalar.GetBoolData()
-			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
-				baseScalar.GetBoolData().Data[baseIdx] = updateData.Data[updateFieldIdx]
+			baseData := baseScalar.GetBoolData()
+			if updateData != nil && baseData != nil &&
+				int(updateIdx) < len(updateData.Data) && int(baseIdx) < len(baseData.Data) {
+				baseData.Data[baseIdx] = updateData.Data[updateIdx]
 			}
 		case *schemapb.ScalarField_IntData:
 			updateData := updateScalar.GetIntData()
-			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
-				baseScalar.GetIntData().Data[baseIdx] = updateData.Data[updateFieldIdx]
+			baseData := baseScalar.GetIntData()
+			if updateData != nil && baseData != nil &&
+				int(updateIdx) < len(updateData.Data) && int(baseIdx) < len(baseData.Data) {
+				baseData.Data[baseIdx] = updateData.Data[updateIdx]
 			}
 		case *schemapb.ScalarField_LongData:
 			updateData := updateScalar.GetLongData()
-			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
-				baseScalar.GetLongData().Data[baseIdx] = updateData.Data[updateFieldIdx]
+			baseData := baseScalar.GetLongData()
+			if updateData != nil && baseData != nil &&
+				int(updateIdx) < len(updateData.Data) && int(baseIdx) < len(baseData.Data) {
+				baseData.Data[baseIdx] = updateData.Data[updateIdx]
 			}
 		case *schemapb.ScalarField_FloatData:
 			updateData := updateScalar.GetFloatData()
-			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
-				baseScalar.GetFloatData().Data[baseIdx] = updateData.Data[updateFieldIdx]
+			baseData := baseScalar.GetFloatData()
+			if updateData != nil && baseData != nil &&
+				int(updateIdx) < len(updateData.Data) && int(baseIdx) < len(baseData.Data) {
+				baseData.Data[baseIdx] = updateData.Data[updateIdx]
 			}
 		case *schemapb.ScalarField_DoubleData:
 			updateData := updateScalar.GetDoubleData()
-			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
-				baseScalar.GetDoubleData().Data[baseIdx] = updateData.Data[updateFieldIdx]
+			baseData := baseScalar.GetDoubleData()
+			if updateData != nil && baseData != nil &&
+				int(updateIdx) < len(updateData.Data) && int(baseIdx) < len(baseData.Data) {
+				baseData.Data[baseIdx] = updateData.Data[updateIdx]
 			}
 		case *schemapb.ScalarField_StringData:
 			updateData := updateScalar.GetStringData()
-			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
-				baseScalar.GetStringData().Data[baseIdx] = updateData.Data[updateFieldIdx]
+			baseData := baseScalar.GetStringData()
+			if updateData != nil && baseData != nil &&
+				int(updateIdx) < len(updateData.Data) && int(baseIdx) < len(baseData.Data) {
+				baseData.Data[baseIdx] = updateData.Data[updateIdx]
 			}
 		case *schemapb.ScalarField_ArrayData:
 			updateData := updateScalar.GetArrayData()
-			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
-				baseScalar.GetArrayData().Data[baseIdx] = updateData.Data[updateFieldIdx]
+			baseData := baseScalar.GetArrayData()
+			if updateData != nil && baseData != nil &&
+				int(updateIdx) < len(updateData.Data) && int(baseIdx) < len(baseData.Data) {
+				baseData.Data[baseIdx] = updateData.Data[updateIdx]
 			}
 		case *schemapb.ScalarField_JsonData:
 			updateData := updateScalar.GetJsonData()
 			baseData := baseScalar.GetJsonData()
-			if updateData != nil && int(updateFieldIdx) < len(updateData.Data) {
+			if updateData != nil && baseData != nil &&
+				int(updateIdx) < len(updateData.Data) && int(baseIdx) < len(baseData.Data) {
 				if baseFieldData.GetIsDynamic() {
 					// dynamic field is a json with only 1 level nested struct,
 					// so we need to unmarshal and iterate updateData's key value, and update the baseData's key value
@@ -1186,7 +1162,7 @@ func UpdateFieldData(base, update []*schemapb.FieldData, baseIdx, updateIdx int6
 					if err := json.Unmarshal(baseData.Data[baseIdx], &baseMap); err != nil {
 						return fmt.Errorf("failed to unmarshal base json: %v", err)
 					}
-					if err := json.Unmarshal(updateData.Data[updateFieldIdx], &updateMap); err != nil {
+					if err := json.Unmarshal(updateData.Data[updateIdx], &updateMap); err != nil {
 						return fmt.Errorf("failed to unmarshal update json: %v", err)
 					}
 					// merge
@@ -1200,7 +1176,7 @@ func UpdateFieldData(base, update []*schemapb.FieldData, baseIdx, updateIdx int6
 				}
 				baseScalar.GetJsonData().Data[baseIdx] = newJSON
 			} else {
-				baseScalar.GetJsonData().Data[baseIdx] = updateData.Data[updateFieldIdx]
+				baseScalar.GetJsonData().Data[baseIdx] = updateData.Data[updateIdx]
 			}
 		}
 	default:
@@ -1216,73 +1192,71 @@ func UpdateFieldData(base, update []*schemapb.FieldData, baseIdx, updateIdx int6
 
 		switch baseVector.Data.(type) {
 		case *schemapb.VectorField_BinaryVector:
-			updateData := updateVector.GetBinaryVector()
-			if updateData != nil {
-				baseData := baseVector.GetBinaryVector()
-				baseStartIdx := updateFieldIdx * (dim / 8)
-				baseEndIdx := (updateFieldIdx + 1) * (dim / 8)
-				updateStartIdx := updateFieldIdx * (dim / 8)
-				updateEndIdx := (updateFieldIdx + 1) * (dim / 8)
+			baseData := baseVector.GetBinaryVector()
+			updateData := updateVector.GetBinaryVector()
+			if baseData != nil && updateData != nil {
+				baseStartIdx := baseIdx * (dim / 8)
+				baseEndIdx := (baseIdx + 1) * (dim / 8)
+				updateStartIdx := updateIdx * (dim / 8)
+				updateEndIdx := (updateIdx + 1) * (dim / 8)
 				if int(updateEndIdx) <= len(updateData) && int(baseEndIdx) <= len(baseData) {
 					copy(baseData[baseStartIdx:baseEndIdx], updateData[updateStartIdx:updateEndIdx])
 				}
 			}
 		case *schemapb.VectorField_FloatVector:
-			updateData := updateVector.GetFloatVector()
-			if updateData != nil {
-				baseData := baseVector.GetFloatVector()
-				baseStartIdx := updateFieldIdx * dim
-				baseEndIdx := (updateFieldIdx + 1) * dim
-				updateStartIdx := updateFieldIdx * dim
-				updateEndIdx := (updateFieldIdx + 1) * dim
+			baseData := baseVector.GetFloatVector()
+			updateData := updateVector.GetFloatVector()
+			if baseData != nil && updateData != nil {
+				baseStartIdx := baseIdx * dim
+				baseEndIdx := (baseIdx + 1) * dim
+				updateStartIdx := updateIdx * dim
+				updateEndIdx := (updateIdx + 1) * dim
 				if int(updateEndIdx) <= len(updateData.Data) && int(baseEndIdx) <= len(baseData.Data) {
 					copy(baseData.Data[baseStartIdx:baseEndIdx], updateData.Data[updateStartIdx:updateEndIdx])
 				}
 			}
 		case *schemapb.VectorField_Float16Vector:
-			updateData := updateVector.GetFloat16Vector()
-			if updateData != nil {
-				baseData := baseVector.GetFloat16Vector()
-				baseStartIdx := updateFieldIdx * (dim * 2)
-				baseEndIdx := (updateFieldIdx + 1) * (dim * 2)
-				updateStartIdx := updateFieldIdx * (dim * 2)
-				updateEndIdx := (updateFieldIdx + 1) * (dim * 2)
+			baseData := baseVector.GetFloat16Vector()
+			updateData := updateVector.GetFloat16Vector()
+			if baseData != nil && updateData != nil {
+				baseStartIdx := baseIdx * (dim * 2)
+				baseEndIdx := (baseIdx + 1) * (dim * 2)
+				updateStartIdx := updateIdx * (dim * 2)
+				updateEndIdx := (updateIdx + 1) * (dim * 2)
 				if int(updateEndIdx) <= len(updateData) && int(baseEndIdx) <= len(baseData) {
 					copy(baseData[baseStartIdx:baseEndIdx], updateData[updateStartIdx:updateEndIdx])
 				}
 			}
 		case *schemapb.VectorField_Bfloat16Vector:
-			updateData := updateVector.GetBfloat16Vector()
-			if updateData != nil {
-				baseData := baseVector.GetBfloat16Vector()
-				baseStartIdx := updateFieldIdx * (dim * 2)
-				baseEndIdx := (updateFieldIdx + 1) * (dim * 2)
-				updateStartIdx := updateFieldIdx * (dim * 2)
-				updateEndIdx := (updateFieldIdx + 1) * (dim * 2)
+			baseData := baseVector.GetBfloat16Vector()
+			updateData := updateVector.GetBfloat16Vector()
+			if baseData != nil && updateData != nil {
+				baseStartIdx := baseIdx * (dim * 2)
+				baseEndIdx := (baseIdx + 1) * (dim * 2)
+				updateStartIdx := updateIdx * (dim * 2)
+				updateEndIdx := (updateIdx + 1) * (dim * 2)
 				if int(updateEndIdx) <= len(updateData) && int(baseEndIdx) <= len(baseData) {
 					copy(baseData[baseStartIdx:baseEndIdx], updateData[updateStartIdx:updateEndIdx])
 				}
 			}
 		case *schemapb.VectorField_SparseFloatVector:
-			updateData := updateVector.GetSparseFloatVector()
-			if updateData != nil && int(updateFieldIdx) < len(updateData.Contents) {
-				baseData := baseVector.GetSparseFloatVector()
-				if int(updateFieldIdx) < len(baseData.Contents) {
-					baseData.Contents[updateFieldIdx] = updateData.Contents[updateFieldIdx]
-					// Update dimension if necessary
-					if updateData.Dim > baseData.Dim {
-						baseData.Dim = updateData.Dim
-					}
+			baseData := baseVector.GetSparseFloatVector()
+			updateData := updateVector.GetSparseFloatVector()
+			if baseData != nil && updateData != nil && int(baseIdx) < len(baseData.Contents) && int(updateIdx) < len(updateData.Contents) {
+				baseData.Contents[baseIdx] = updateData.Contents[updateIdx]
+				// Update dimension if necessary
+				if updateData.Dim > baseData.Dim {
+					baseData.Dim = updateData.Dim
 				}
 			}
 		case *schemapb.VectorField_Int8Vector:
-			updateData := updateVector.GetInt8Vector()
-			if updateData != nil {
-				baseData := baseVector.GetInt8Vector()
-				baseStartIdx := updateFieldIdx * dim
-				baseEndIdx := (updateFieldIdx + 1) * dim
-				updateStartIdx := updateFieldIdx * dim
-				updateEndIdx := (updateFieldIdx + 1) * dim
+			baseData := baseVector.GetInt8Vector()
+			updateData := updateVector.GetInt8Vector()
+			if baseData != nil && updateData != nil {
+				baseStartIdx := baseIdx * dim
+				baseEndIdx := (baseIdx + 1) * dim
+				updateStartIdx := updateIdx * dim
+				updateEndIdx := (updateIdx + 1) * dim
 				if int(updateEndIdx) <= len(updateData) && int(baseEndIdx) <= len(baseData) {
 					copy(baseData[baseStartIdx:baseEndIdx], updateData[updateStartIdx:updateEndIdx])
 				}
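Vector fields carry no ValidData, so an update overwrites a fixed-width window of the flat data array; the per-row stride follows the element encoding: dim/8 bytes for binary vectors, dim float32 values for float vectors, dim*2 bytes for float16/bfloat16, and dim bytes for int8. A worked example with hypothetical numbers: for a binary vector with dim=16, row baseIdx=2 occupies bytes [2*(16/8), 3*(16/8)) = [4, 6) of the base array. The old code derived both windows from updateFieldIdx, so when base and update rows differed it could write the wrong row of the base array or run past its end; deriving each window from its own index, plus the explicit bounds check, fixes both.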
File diff suppressed because it is too large
@@ -3931,3 +3931,103 @@ def parse_fmod(x: int, y: int) -> int:
     v = abs(x) % abs(y)
 
     return v if x >= 0 else -v
+
+
+def gen_partial_row_data_by_schema(nb=ct.default_nb, schema=None, desired_field_names=None, num_fields=1,
+                                   start=0, random_pk=False, skip_field_names=[]):
+    """
+    Generate row data that contains a subset of fields from the given schema.
+
+    Args:
+        nb (int): Number of rows to generate. Defaults to ct.default_nb.
+        schema: Collection schema or collection info dict. If None, uses default schema.
+        desired_field_names (list[str] | None): Explicit field names to include (intersected with eligible fields).
+        num_fields (int): Number of fields to include if desired_field_names is not provided. Defaults to 1.
+        start (int): Starting value for primary key fields when sequential values are needed.
+        random_pk (bool): Whether to generate random primary key values.
+        skip_field_names (list[str]): Field names to skip.
+
+    Returns:
+        list[dict]: a list of rows.
+
+    Notes:
+        - Skips auto_id fields and function output fields.
+        - Primary INT64/VARCHAR fields get sequential values from `start` unless `random_pk=True`.
+        - Works with both schema dicts (from v2 client describe_collection) and ORM schema objects.
+    """
+    if schema is None:
+        schema = gen_default_collection_schema()
+    func_output_fields = []
+    # Build list of eligible fields
+    if isinstance(schema, dict):
+        fields = schema.get('fields', [])
+        functions = schema.get('functions', [])
+        for func in functions:
+            output_field_names = func.get('output_field_names', [])
+            func_output_fields.extend(output_field_names)
+        func_output_fields = list(set(func_output_fields))
+        eligible_fields = []
+        for field in fields:
+            field_name = field.get('name', None)
+            if field.get('auto_id', False):
+                continue
+            if field_name in func_output_fields or field_name in skip_field_names:
+                continue
+            eligible_fields.append(field)
+        # Choose subset
+        if desired_field_names:
+            desired_set = set(desired_field_names)
+            chosen_fields = [f for f in eligible_fields if f.get('name') in desired_set]
+        else:
+            n = max(0, min(len(eligible_fields), num_fields if num_fields is not None else 1))
+            chosen_fields = eligible_fields[:n]
+        rows = []
+        curr_start = start
+        for _ in range(nb):
+            row = {}
+            for field in chosen_fields:
+                fname = field.get('name', None)
+                value = gen_data_by_collection_field(field, random_pk=random_pk)
+                # Override for PKs when not random
+                if not random_pk and field.get('is_primary', False) is True:
+                    if field.get('type', None) == DataType.INT64:
+                        value = curr_start
+                        curr_start += 1
+                    elif field.get('type', None) == DataType.VARCHAR:
+                        value = str(curr_start)
+                        curr_start += 1
+                row[fname] = value
+            rows.append(row)
+        return rows
+    # ORM schema path
+    fields = schema.fields
+    if hasattr(schema, "functions"):
+        functions = schema.functions
+        for func in functions:
+            func_output_fields.extend(func.output_field_names)
+    func_output_fields = list(set(func_output_fields))
+    eligible_fields = []
+    for field in fields:
+        if field.auto_id:
+            continue
+        if field.name in func_output_fields or field.name in skip_field_names:
+            continue
+        eligible_fields.append(field)
+    if desired_field_names:
+        desired_set = set(desired_field_names)
+        chosen_fields = [f for f in eligible_fields if f.name in desired_set]
+    else:
+        n = max(0, min(len(eligible_fields), num_fields if num_fields is not None else 1))
+        chosen_fields = eligible_fields[:n]
+    rows = []
+    curr_start = start
+    for _ in range(nb):
+        row = {}
+        for field in chosen_fields:
+            value = gen_data_by_collection_field(field, random_pk=random_pk)
+            if not random_pk and field.is_primary is True:
+                if field.dtype == DataType.INT64:
+                    value = curr_start
+                    curr_start += 1
+                elif field.dtype == DataType.VARCHAR:
+                    value = str(curr_start)
+                    curr_start += 1
+            row[field.name] = value
+        rows.append(row)
+    return rows
@@ -0,0 +1,314 @@
+import pytest
+import time
+import random
+import numpy as np
+from common.common_type import CaseLabel, CheckTasks
+from common import common_func as cf
+from common import common_type as ct
+from utils.util_log import test_log as log
+from utils.util_pymilvus import *
+from base.client_v2_base import TestMilvusClientV2Base
+from pymilvus import DataType, FieldSchema, CollectionSchema
+from sklearn import preprocessing
+
+# Test parameters
+default_nb = ct.default_nb
+default_nq = ct.default_nq
+default_limit = ct.default_limit
+default_search_exp = "id >= 0"
+exp_res = "exp_res"
+default_primary_key_field_name = "id"
+default_vector_field_name = "vector"
+default_int32_field_name = ct.default_int32_field_name
+
+
+class TestMilvusClientPartialUpdate(TestMilvusClientV2Base):
+    """ Test case of partial update functionality """
+
+    @pytest.mark.tags(CaseLabel.L0)
+    def test_partial_update_all_field_types(self):
+        """
+        Test partial update functionality with all field types
+        1. Create collection with all data types
+        2. Insert initial data
+        3. Perform partial update for each field type
+        4. Verify all updates work correctly
+        """
+        client = self._client()
+        dim = 64
+        collection_name = cf.gen_collection_name_by_testcase_name()
+
+        # Create schema with all data types
+        schema = cf.gen_all_datatype_collection_schema(dim=dim)
+
+        # Create index parameters
+        index_params = client.prepare_index_params()
+        for i in range(len(schema.fields)):
+            field_name = schema.fields[i].name
+            print(f"field_name: {field_name}")
+            if field_name == "json_field":
+                index_params.add_index(field_name, index_type="AUTOINDEX",
+                                       params={"json_cast_type": "json"})
+            elif field_name == "text_sparse_emb":
+                index_params.add_index(field_name, index_type="AUTOINDEX", metric_type="BM25")
+            else:
+                index_params.add_index(field_name, index_type="AUTOINDEX")
+
+        # Create collection
+        client.create_collection(collection_name, default_dim, consistency_level="Strong", schema=schema, index_params=index_params)
+
+        # Load collection
+        self.load_collection(client, collection_name)
+
+        # Insert initial data
+        nb = 1000
+        rows = cf.gen_row_data_by_schema(nb=nb, schema=schema)
+        self.upsert(client, collection_name, rows, partial_update=True)
+        log.info(f"Inserted {nb} initial records")
+
+        primary_key_field_name = schema.fields[0].name
+        for i in range(len(schema.fields)):
+            update_field_name = schema.fields[i if i != 0 else 1].name
+            new_row = cf.gen_partial_row_data_by_schema(nb=nb, schema=schema,
+                                                        desired_field_names=[primary_key_field_name, update_field_name])
+            client.upsert(collection_name, new_row, partial_update=True)
+
+        log.info("Partial update test for all field types passed successfully")
+
+    @pytest.mark.tags(CaseLabel.L0)
+    def test_partial_update_simple_demo(self):
+        """
+        Test simple partial update demo with nullable fields
+        1. Create collection with explicit schema including nullable fields
+        2. Insert initial data with some null values
+        3. Perform partial updates with different field combinations
+        4. Verify partial update behavior preserves unchanged fields
+        """
+        client = self._client()
+        dim = 3
+        collection_name = cf.gen_collection_name_by_testcase_name()
+
+        # Create schema with nullable fields
+        schema = self.create_schema(client, enable_dynamic_field=False)[0]
+        schema.add_field("id", DataType.INT64, is_primary=True, auto_id=False)
+        schema.add_field("vector", DataType.FLOAT_VECTOR, dim=dim)
+        schema.add_field("name", DataType.VARCHAR, max_length=100, nullable=True)
+        schema.add_field("price", DataType.FLOAT, nullable=True)
+        schema.add_field("category", DataType.VARCHAR, max_length=50, nullable=True)
+
+        # Create collection
+        self.create_collection(client, collection_name, schema=schema)
+
+        # Create index
+        index_params = self.prepare_index_params(client)[0]
+        index_params.add_index("vector", index_type="AUTOINDEX", metric_type="L2")
+        self.create_index(client, collection_name, index_params=index_params)
+
+        # Load collection
+        self.load_collection(client, collection_name)
+
+        # Insert initial data with some null values
+        initial_data = [
+            {
+                "id": 1,
+                "vector": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist(),
+                "name": "Product A",
+                "price": 100.0,
+                "category": "Electronics"
+            },
+            {
+                "id": 2,
+                "vector": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist(),
+                "name": "Product B",
+                "price": None,  # Null price
+                "category": "Home"
+            },
+            {
+                "id": 3,
+                "vector": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist(),
+                "name": "Product C",
+                "price": None,  # Null price
+                "category": "Books"
+            }
+        ]
+
+        self.upsert(client, collection_name, initial_data, partial_update=False)
+        log.info("Inserted initial data with null values")
+
+        # Verify initial state
+        results = self.query(client, collection_name, filter="id > 0", output_fields=["*"])[0]
+        assert len(results) == 3
+
+        initial_data_map = {data['id']: data for data in results}
+        assert initial_data_map[1]['name'] == "Product A"
+        assert initial_data_map[1]['price'] == 100.0
+        assert initial_data_map[1]['category'] == "Electronics"
+        assert initial_data_map[2]['name'] == "Product B"
+        assert initial_data_map[2]['price'] is None
+        assert initial_data_map[2]['category'] == "Home"
+        assert initial_data_map[3]['name'] == "Product C"
+        assert initial_data_map[3]['price'] is None
+        assert initial_data_map[3]['category'] == "Books"
+
+        log.info("Initial data verification passed")
+
+        # First partial update - update all fields
+        log.info("First partial update - updating all fields...")
+        first_update_data = [
+            {
+                "id": 1,
+                "vector": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist(),
+                "name": "Product A-Update",
+                "price": 111.1,
+                "category": "Electronics-Update"
+            },
+            {
+                "id": 2,
+                "vector": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist(),
+                "name": "Product B-Update",
+                "price": 222.2,
+                "category": "Home-Update"
+            },
+            {
+                "id": 3,
+                "vector": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist(),
+                "name": "Product C-Update",
+                "price": None,  # Still null
+                "category": "Books-Update"
+            }
+        ]
+
+        self.upsert(client, collection_name, first_update_data, partial_update=True)
+
+        # Verify first update
+        results = self.query(client, collection_name, filter="id > 0", output_fields=["*"])[0]
+        assert len(results) == 3
+
+        first_update_map = {data['id']: data for data in results}
+        assert first_update_map[1]['name'] == "Product A-Update"
+        assert abs(first_update_map[1]['price'] - 111.1) < 0.001
+        assert first_update_map[1]['category'] == "Electronics-Update"
+        assert first_update_map[2]['name'] == "Product B-Update"
+        assert abs(first_update_map[2]['price'] - 222.2) < 0.001
+        assert first_update_map[2]['category'] == "Home-Update"
+        assert first_update_map[3]['name'] == "Product C-Update"
+        assert first_update_map[3]['price'] is None
+        assert first_update_map[3]['category'] == "Books-Update"
+
+        log.info("First partial update verification passed")
+
+        # Second partial update - update only specific fields
+        log.info("Second partial update - updating specific fields...")
+        second_update_data = [
+            {
+                "id": 1,
+                "name": "Product A-Update-Again",
+                "price": 1111.1,
+                "category": "Electronics-Update-Again"
+            },
+            {
+                "id": 2,
+                "name": "Product B-Update-Again",
+                "price": None,  # Set back to null
+                "category": "Home-Update-Again"
+            },
+            {
+                "id": 3,
+                "name": "Product C-Update-Again",
+                "price": 3333.3,  # Set price from null to value
+                "category": "Books-Update-Again"
+            }
+        ]
+
+        self.upsert(client, collection_name, second_update_data, partial_update=True)
+
+        # Verify second update
+        results = self.query(client, collection_name, filter="id > 0", output_fields=["*"])[0]
+        assert len(results) == 3
+
+        second_update_map = {data['id']: data for data in results}
+
+        # Verify ID 1: all fields updated
+        assert second_update_map[1]['name'] == "Product A-Update-Again"
+        assert abs(second_update_map[1]['price'] - 1111.1) < 0.001
+        assert second_update_map[1]['category'] == "Electronics-Update-Again"
+
+        # Verify ID 2: all fields updated, price set to null
+        assert second_update_map[2]['name'] == "Product B-Update-Again"
+        assert second_update_map[2]['price'] is None
+        assert second_update_map[2]['category'] == "Home-Update-Again"
+
+        # Verify ID 3: all fields updated, price set from null to value
+        assert second_update_map[3]['name'] == "Product C-Update-Again"
+        assert abs(second_update_map[3]['price'] - 3333.3) < 0.001
+        assert second_update_map[3]['category'] == "Books-Update-Again"
+
+        # Verify vector fields were preserved from first update (not updated in second update)
+        # Note: Vector comparison might be complex, so we just verify they exist
+        assert 'vector' in second_update_map[1]
+        assert 'vector' in second_update_map[2]
+        assert 'vector' in second_update_map[3]
+
+        log.info("Second partial update verification passed")
+        log.info("Simple partial update demo test completed successfully")
+
+    @pytest.mark.tags(CaseLabel.L0)
+    def test_milvus_client_partial_update_null_to_null(self):
+        """
+        Target: test PU can successfully update a null to null
+        Method:
+        1. Create a collection, enable nullable fields
+        2. Insert default_nb rows to the collection
+        3. Partial Update the nullable field with null
+        4. Query the collection to check the value of nullable field
+        Expected: query should have correct value and number of entities
+        """
+        # step 1: create collection with nullable fields
+        client = self._client()
+        schema = self.create_schema(client, enable_dynamic_field=False)[0]
+        schema.add_field(default_primary_key_field_name, DataType.INT64, is_primary=True, auto_id=False)
+        schema.add_field(default_vector_field_name, DataType.FLOAT_VECTOR, dim=default_dim)
+        schema.add_field(default_int32_field_name, DataType.INT32, nullable=True)
+
+        index_params = self.prepare_index_params(client)[0]
+        index_params.add_index(default_primary_key_field_name, index_type="AUTOINDEX")
+        index_params.add_index(default_vector_field_name, index_type="AUTOINDEX")
+        index_params.add_index(default_int32_field_name, index_type="AUTOINDEX")
+
+        collection_name = cf.gen_collection_name_by_testcase_name(module_index=1)
+        self.create_collection(client, collection_name, default_dim, schema=schema,
+                               consistency_level="Strong", index_params=index_params)
+
+        # step 2: insert default_nb rows to the collection
+        rows = cf.gen_row_data_by_schema(nb=default_nb, schema=schema, skip_field_names=[default_int32_field_name])
+        self.upsert(client, collection_name, rows, partial_update=True)
+
+        # step 3: Partial Update the nullable field with null
+        new_row = cf.gen_partial_row_data_by_schema(
+            nb=default_nb,
+            schema=schema,
+            desired_field_names=[default_primary_key_field_name, default_int32_field_name],
+            start=0
+        )
+
+        # Set the nullable field to None
+        for data in new_row:
+            data[default_int32_field_name] = None
+
+        self.upsert(client, collection_name, new_row, partial_update=True)
+
+        # step 4: Query the collection to check the value of nullable field
+        result = self.query(client, collection_name, filter=default_search_exp,
+                            check_task=CheckTasks.check_query_results,
+                            output_fields=[default_int32_field_name],
+                            check_items={exp_res: new_row,
+                                         "with_vec": True,
+                                         "pk_name": default_primary_key_field_name})[0]
+        assert len(result) == default_nb
+
+        # Verify that all nullable fields are indeed null
+        for data in result:
+            assert data[default_int32_field_name] is None, f"Expected null value for {default_int32_field_name}, got {data[default_int32_field_name]}"
+
+        log.info("Partial update null to null test completed successfully")
+        self.drop_collection(client, collection_name)