From 7aed88113cb2bcd5e1a1c0d74067c03e14fb3e31 Mon Sep 17 00:00:00 2001 From: wei liu Date: Mon, 17 Nov 2025 21:35:40 +0800 Subject: [PATCH] enhance: Deduplicate primary keys in upsert request batch (#45249) issue: #44320 This change adds deduplication logic to handle duplicate primary keys within a single upsert batch, keeping the last occurrence of each primary key. Key changes: - Add DeduplicateFieldData function to remove duplicate PKs from field data, supporting both Int64 and VarChar primary keys - Refactor fillFieldPropertiesBySchema into two separate functions: validateFieldDataColumns for validation and fillFieldPropertiesOnly for property filling, improving code clarity and reusability - Integrate deduplication logic in upsertTask.PreExecute to automatically deduplicate data before processing - Add comprehensive unit tests for deduplication with various PK types (Int64, VarChar) and field types (scalar, vector) - Add Python integration tests to verify end-to-end behavior --------- Signed-off-by: Wei Liu --- internal/proxy/task_insert.go | 12 +- internal/proxy/task_upsert.go | 32 +- internal/proxy/task_upsert_test.go | 332 +++++++++++++++++ internal/proxy/util.go | 178 +++++++-- internal/proxy/util_test.go | 76 +++- .../test_milvus_client_upsert.py | 341 +++++++++++++++++- tests/python_client/testcases/test_insert.py | 2 +- 7 files changed, 911 insertions(+), 62 deletions(-) diff --git a/internal/proxy/task_insert.go b/internal/proxy/task_insert.go index 8dc9d92c2a..efe2a8a6e6 100644 --- a/internal/proxy/task_insert.go +++ b/internal/proxy/task_insert.go @@ -235,11 +235,15 @@ func (it *insertTask) PreExecute(ctx context.Context) error { return err } - // set field ID to insert field data - err = fillFieldPropertiesBySchema(it.insertMsg.GetFieldsData(), schema.CollectionSchema) + // Validate and set field ID to insert field data + err = validateFieldDataColumns(it.insertMsg.GetFieldsData(), schema) if err != nil { - log.Info("set fieldID to fieldData failed", - zap.Error(err)) + log.Info("validate field data columns failed", zap.Error(err)) + return err + } + err = fillFieldPropertiesOnly(it.insertMsg.GetFieldsData(), schema) + if err != nil { + log.Info("fill field properties failed", zap.Error(err)) return err } diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go index 9505f2b2c6..dc00d5bd3f 100644 --- a/internal/proxy/task_upsert.go +++ b/internal/proxy/task_upsert.go @@ -899,11 +899,15 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error { return err } - // set field ID to insert field data - err = fillFieldPropertiesBySchema(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema.CollectionSchema) + // Validate and set field ID to insert field data + err = validateFieldDataColumns(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema) if err != nil { - log.Warn("insert set fieldID to fieldData failed when upsert", - zap.Error(err)) + log.Warn("validate field data columns failed when upsert", zap.Error(err)) + return merr.WrapErrAsInputErrorWhen(err, merr.ErrParameterInvalid) + } + err = fillFieldPropertiesOnly(it.upsertMsg.InsertMsg.GetFieldsData(), it.schema) + if err != nil { + log.Warn("fill field properties failed when upsert", zap.Error(err)) return merr.WrapErrAsInputErrorWhen(err, merr.ErrParameterInvalid) } @@ -1068,6 +1072,26 @@ func (it *upsertTask) PreExecute(ctx context.Context) error { } } + // deduplicate upsert data to handle duplicate primary keys in the same batch + primaryFieldSchema, err := 
typeutil.GetPrimaryFieldSchema(schema.CollectionSchema) + if err != nil { + log.Warn("fail to get primary field schema", zap.Error(err)) + return err + } + deduplicatedFieldsData, newNumRows, err := DeduplicateFieldData(primaryFieldSchema, it.req.GetFieldsData(), schema) + if err != nil { + log.Warn("fail to deduplicate upsert data", zap.Error(err)) + } + + // dedup won't decrease numOfRows to 0 + if newNumRows > 0 && newNumRows != it.req.NumRows { + log.Info("upsert data deduplicated", + zap.Uint32("original_num_rows", it.req.NumRows), + zap.Uint32("deduplicated_num_rows", newNumRows)) + it.req.FieldsData = deduplicatedFieldsData + it.req.NumRows = newNumRows + } + it.upsertMsg = &msgstream.UpsertMsg{ InsertMsg: &msgstream.InsertMsg{ InsertRequest: &msgpb.InsertRequest{ diff --git a/internal/proxy/task_upsert_test.go b/internal/proxy/task_upsert_test.go index f2bb10ad20..5f06c25395 100644 --- a/internal/proxy/task_upsert_test.go +++ b/internal/proxy/task_upsert_test.go @@ -1051,6 +1051,7 @@ func TestUpdateTask_PreExecute_InvalidNumRows(t *testing.T) { }, nil).Build() task := createTestUpdateTask() + task.req.FieldsData = []*schemapb.FieldData{} task.req.NumRows = 0 // Invalid num_rows err := task.PreExecute(context.Background()) @@ -1534,3 +1535,334 @@ func TestUpsertTask_PlanNamespace_AfterPreExecute(t *testing.T) { assert.Equal(t, *task.req.Namespace, *capturedPlan.Namespace) }) } + +func TestUpsertTask_Deduplicate_Int64PK(t *testing.T) { + // Test deduplication with Int64 primary key + primaryFieldSchema := &schemapb.FieldSchema{ + Name: "id", + FieldID: 100, + DataType: schemapb.DataType_Int64, + IsPrimaryKey: true, + } + + collSchema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + primaryFieldSchema, + { + Name: "float_field", + FieldID: 101, + DataType: schemapb.DataType_Float, + }, + }, + } + schema := newSchemaInfo(collSchema) + + // Create field data with duplicate IDs: [1, 2, 3, 2, 1] + // Expected to keep last occurrence of each: [3, 2, 1] (indices 2, 3, 4) + fieldsData := []*schemapb.FieldData{ + { + FieldName: "id", + Type: schemapb.DataType_Int64, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 2, 1}, + }, + }, + }, + }, + }, + { + FieldName: "float_field", + Type: schemapb.DataType_Float, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: []float32{1.1, 2.2, 3.3, 2.4, 1.5}, + }, + }, + }, + }, + }, + } + + deduplicatedFields, newNumRows, err := DeduplicateFieldData(primaryFieldSchema, fieldsData, schema) + assert.NoError(t, err) + assert.Equal(t, uint32(3), newNumRows) + assert.Equal(t, 2, len(deduplicatedFields)) + + // Check deduplicated primary keys + pkField := deduplicatedFields[0] + pkData := pkField.GetScalars().GetLongData().GetData() + assert.Equal(t, 3, len(pkData)) + assert.Equal(t, []int64{3, 2, 1}, pkData) + + // Check corresponding float values (should be 3.3, 2.4, 1.5) + floatField := deduplicatedFields[1] + floatData := floatField.GetScalars().GetFloatData().GetData() + assert.Equal(t, 3, len(floatData)) + assert.Equal(t, []float32{3.3, 2.4, 1.5}, floatData) +} + +func TestUpsertTask_Deduplicate_VarCharPK(t *testing.T) { + // Test deduplication with VarChar primary key + primaryFieldSchema := &schemapb.FieldSchema{ + Name: "id", + FieldID: 100, + DataType: schemapb.DataType_VarChar, + IsPrimaryKey: true, + } + + 
collSchema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + primaryFieldSchema, + { + Name: "int_field", + FieldID: 101, + DataType: schemapb.DataType_Int64, + }, + }, + } + schema := newSchemaInfo(collSchema) + + // Create field data with duplicate IDs: ["a", "b", "c", "b", "a"] + // Expected to keep last occurrence of each: ["c", "b", "a"] (indices 2, 3, 4) + fieldsData := []*schemapb.FieldData{ + { + FieldName: "id", + Type: schemapb.DataType_VarChar, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"a", "b", "c", "b", "a"}, + }, + }, + }, + }, + }, + { + FieldName: "int_field", + Type: schemapb.DataType_Int64, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{100, 200, 300, 201, 101}, + }, + }, + }, + }, + }, + } + + deduplicatedFields, newNumRows, err := DeduplicateFieldData(primaryFieldSchema, fieldsData, schema) + assert.NoError(t, err) + assert.Equal(t, uint32(3), newNumRows) + assert.Equal(t, 2, len(deduplicatedFields)) + + // Check deduplicated primary keys + pkField := deduplicatedFields[0] + pkData := pkField.GetScalars().GetStringData().GetData() + assert.Equal(t, 3, len(pkData)) + assert.Equal(t, []string{"c", "b", "a"}, pkData) + + // Check corresponding int64 values (should be 300, 201, 101) + int64Field := deduplicatedFields[1] + int64Data := int64Field.GetScalars().GetLongData().GetData() + assert.Equal(t, 3, len(int64Data)) + assert.Equal(t, []int64{300, 201, 101}, int64Data) +} + +func TestUpsertTask_Deduplicate_NoDuplicates(t *testing.T) { + // Test with no duplicates - should return original data + primaryFieldSchema := &schemapb.FieldSchema{ + Name: "id", + FieldID: 100, + DataType: schemapb.DataType_Int64, + IsPrimaryKey: true, + } + + collSchema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + primaryFieldSchema, + }, + } + schema := newSchemaInfo(collSchema) + + fieldsData := []*schemapb.FieldData{ + { + FieldName: "id", + Type: schemapb.DataType_Int64, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5}, + }, + }, + }, + }, + }, + } + + deduplicatedFields, newNumRows, err := DeduplicateFieldData(primaryFieldSchema, fieldsData, schema) + assert.NoError(t, err) + assert.Equal(t, uint32(5), newNumRows) + assert.Equal(t, 1, len(deduplicatedFields)) + + // Should be unchanged + pkField := deduplicatedFields[0] + pkData := pkField.GetScalars().GetLongData().GetData() + assert.Equal(t, []int64{1, 2, 3, 4, 5}, pkData) +} + +func TestUpsertTask_Deduplicate_WithVector(t *testing.T) { + // Test deduplication with vector field + primaryFieldSchema := &schemapb.FieldSchema{ + Name: "id", + FieldID: 100, + DataType: schemapb.DataType_Int64, + IsPrimaryKey: true, + } + + collSchema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + primaryFieldSchema, + { + Name: "vector", + FieldID: 101, + DataType: schemapb.DataType_FloatVector, + }, + }, + } + schema := newSchemaInfo(collSchema) + + dim := 4 + // Create field data with duplicate IDs: [1, 2, 1] + // Expected to keep indices [1, 2] (last occurrence of 2, last occurrence of 1) + fieldsData := []*schemapb.FieldData{ + { + FieldName: "id", + Type: schemapb.DataType_Int64, + Field: &schemapb.FieldData_Scalars{ + Scalars: 
&schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 1}, + }, + }, + }, + }, + }, + { + FieldName: "vector", + Type: schemapb.DataType_FloatVector, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: int64(dim), + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: []float32{ + 1.0, 1.1, 1.2, 1.3, // vector for ID 1 (first occurrence) + 2.0, 2.1, 2.2, 2.3, // vector for ID 2 + 1.4, 1.5, 1.6, 1.7, // vector for ID 1 (second occurrence - keep this) + }, + }, + }, + }, + }, + }, + } + + deduplicatedFields, newNumRows, err := DeduplicateFieldData(primaryFieldSchema, fieldsData, schema) + assert.NoError(t, err) + assert.Equal(t, uint32(2), newNumRows) + assert.Equal(t, 2, len(deduplicatedFields)) + + // Check deduplicated primary keys + pkField := deduplicatedFields[0] + pkData := pkField.GetScalars().GetLongData().GetData() + assert.Equal(t, 2, len(pkData)) + assert.Equal(t, []int64{2, 1}, pkData) + + // Check corresponding vector (should keep vectors for ID 2 and ID 1's last occurrence) + vectorField := deduplicatedFields[1] + vectorData := vectorField.GetVectors().GetFloatVector().GetData() + assert.Equal(t, 8, len(vectorData)) // 2 vectors * 4 dimensions + expectedVector := []float32{ + 2.0, 2.1, 2.2, 2.3, // vector for ID 2 + 1.4, 1.5, 1.6, 1.7, // vector for ID 1 (last occurrence) + } + assert.Equal(t, expectedVector, vectorData) +} + +func TestUpsertTask_Deduplicate_EmptyData(t *testing.T) { + // Test with empty data + primaryFieldSchema := &schemapb.FieldSchema{ + Name: "id", + FieldID: 100, + DataType: schemapb.DataType_Int64, + IsPrimaryKey: true, + } + + collSchema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + primaryFieldSchema, + }, + } + schema := newSchemaInfo(collSchema) + + fieldsData := []*schemapb.FieldData{} + + deduplicatedFields, newNumRows, err := DeduplicateFieldData(primaryFieldSchema, fieldsData, schema) + assert.NoError(t, err) + assert.Equal(t, uint32(0), newNumRows) + assert.Equal(t, 0, len(deduplicatedFields)) +} + +func TestUpsertTask_Deduplicate_MissingPrimaryKey(t *testing.T) { + // Test with missing primary key field + primaryFieldSchema := &schemapb.FieldSchema{ + Name: "id", + FieldID: 100, + DataType: schemapb.DataType_Int64, + IsPrimaryKey: true, + } + + collSchema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + primaryFieldSchema, + { + Name: "other_field", + FieldID: 101, + DataType: schemapb.DataType_Float, + }, + }, + } + schema := newSchemaInfo(collSchema) + + fieldsData := []*schemapb.FieldData{ + { + FieldName: "other_field", + Type: schemapb.DataType_Float, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_FloatData{ + FloatData: &schemapb.FloatArray{ + Data: []float32{1.1, 2.2}, + }, + }, + }, + }, + }, + } + + _, _, err := DeduplicateFieldData(primaryFieldSchema, fieldsData, schema) + assert.Error(t, err) + // validateFieldDataColumns will fail first due to column count mismatch + // or the function will fail when trying to find primary key + assert.True(t, err != nil) +} diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 0f2453047b..1713c69cb0 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -1046,6 +1046,103 @@ func parsePrimaryFieldData2IDs(fieldData *schemapb.FieldData) (*schemapb.IDs, er return primaryData, nil } +// findLastOccurrenceIndices finds indices of last occurrences for each unique 
ID +func findLastOccurrenceIndices[T comparable](ids []T) []int { + lastOccurrence := make(map[T]int, len(ids)) + for idx, id := range ids { + lastOccurrence[id] = idx + } + + keepIndices := make([]int, 0, len(lastOccurrence)) + for idx, id := range ids { + if lastOccurrence[id] == idx { + keepIndices = append(keepIndices, idx) + } + } + return keepIndices +} + +// DeduplicateFieldData removes duplicate primary keys from field data, +// keeping the last occurrence of each ID +func DeduplicateFieldData(primaryFieldSchema *schemapb.FieldSchema, fieldsData []*schemapb.FieldData, schema *schemaInfo) ([]*schemapb.FieldData, uint32, error) { + if len(fieldsData) == 0 { + return fieldsData, 0, nil + } + + if err := fillFieldPropertiesOnly(fieldsData, schema); err != nil { + return nil, 0, err + } + + // find primary field data + var primaryFieldData *schemapb.FieldData + for _, field := range fieldsData { + if field.GetFieldName() == primaryFieldSchema.GetName() { + primaryFieldData = field + break + } + } + + if primaryFieldData == nil { + return nil, 0, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("must assign pk when upsert, primary field: %v", primaryFieldSchema.GetName())) + } + + // get row count + var numRows int + switch primaryFieldData.Field.(type) { + case *schemapb.FieldData_Scalars: + scalarField := primaryFieldData.GetScalars() + switch scalarField.Data.(type) { + case *schemapb.ScalarField_LongData: + numRows = len(scalarField.GetLongData().GetData()) + case *schemapb.ScalarField_StringData: + numRows = len(scalarField.GetStringData().GetData()) + default: + return nil, 0, merr.WrapErrParameterInvalidMsg("unsupported primary key type") + } + default: + return nil, 0, merr.WrapErrParameterInvalidMsg("primary field must be scalar type") + } + + if numRows == 0 { + return fieldsData, 0, nil + } + + // build map to track last occurrence of each primary key + var keepIndices []int + switch primaryFieldData.Field.(type) { + case *schemapb.FieldData_Scalars: + scalarField := primaryFieldData.GetScalars() + switch scalarField.Data.(type) { + case *schemapb.ScalarField_LongData: + // for Int64 primary keys + intIDs := scalarField.GetLongData().GetData() + keepIndices = findLastOccurrenceIndices(intIDs) + + case *schemapb.ScalarField_StringData: + // for VarChar primary keys + strIDs := scalarField.GetStringData().GetData() + keepIndices = findLastOccurrenceIndices(strIDs) + } + } + + // if no duplicates found, return original data + if len(keepIndices) == numRows { + return fieldsData, uint32(numRows), nil + } + + log.Info("duplicate primary keys detected in upsert request, deduplicating", + zap.Int("original_rows", numRows), + zap.Int("deduplicated_rows", len(keepIndices))) + + // use typeutil.AppendFieldData to rebuild field data with deduplicated rows + result := typeutil.PrepareResultFieldData(fieldsData, int64(len(keepIndices))) + for _, idx := range keepIndices { + typeutil.AppendFieldData(result, fieldsData, int64(idx)) + } + + return result, uint32(len(keepIndices)), nil +} + // autoGenPrimaryFieldData generate primary data when autoID == true func autoGenPrimaryFieldData(fieldSchema *schemapb.FieldSchema, data interface{}) (*schemapb.FieldData, error) { var fieldData schemapb.FieldData @@ -1105,52 +1202,34 @@ func autoGenDynamicFieldData(data [][]byte) *schemapb.FieldData { } } -// fillFieldPropertiesBySchema set fieldID to fieldData according FieldSchemas -func fillFieldPropertiesBySchema(columns []*schemapb.FieldData, schema *schemapb.CollectionSchema) error { - fieldName2Schema 
:= make(map[string]*schemapb.FieldSchema) - +// validateFieldDataColumns validates that all required fields are present and no unknown fields exist. +// It checks: +// 1. The number of columns matches the expected count (excluding BM25 output fields) +// 2. All field names exist in the schema +// Returns detailed error message listing expected and provided fields if validation fails. +func validateFieldDataColumns(columns []*schemapb.FieldData, schema *schemaInfo) error { expectColumnNum := 0 + + // Count expected columns for _, field := range schema.GetFields() { - fieldName2Schema[field.Name] = field - if !typeutil.IsBM25FunctionOutputField(field, schema) { + if !typeutil.IsBM25FunctionOutputField(field, schema.CollectionSchema) { expectColumnNum++ } } - for _, structField := range schema.GetStructArrayFields() { - for _, field := range structField.GetFields() { - fieldName2Schema[field.Name] = field - expectColumnNum++ - } + expectColumnNum += len(structField.GetFields()) } + // Validate column count if len(columns) != expectColumnNum { return fmt.Errorf("len(columns) mismatch the expectColumnNum, expectColumnNum: %d, len(columns): %d", expectColumnNum, len(columns)) } + // Validate field existence using schemaHelper for _, fieldData := range columns { - if fieldSchema, ok := fieldName2Schema[fieldData.FieldName]; ok { - fieldData.FieldId = fieldSchema.FieldID - fieldData.Type = fieldSchema.DataType - - // Set the ElementType because it may not be set in the insert request. - if fieldData.Type == schemapb.DataType_Array { - fd, ok := fieldData.Field.(*schemapb.FieldData_Scalars) - if !ok || fd.Scalars.GetArrayData() == nil { - return fmt.Errorf("field convert FieldData_Scalars fail in fieldData, fieldName: %s,"+ - " collectionName:%s", fieldData.FieldName, schema.Name) - } - fd.Scalars.GetArrayData().ElementType = fieldSchema.ElementType - } else if fieldData.Type == schemapb.DataType_ArrayOfVector { - fd, ok := fieldData.Field.(*schemapb.FieldData_Vectors) - if !ok || fd.Vectors.GetVectorArray() == nil { - return fmt.Errorf("field convert FieldData_Vectors fail in fieldData, fieldName: %s,"+ - " collectionName:%s", fieldData.FieldName, schema.Name) - } - fd.Vectors.GetVectorArray().ElementType = fieldSchema.ElementType - } - } else { + _, err := schema.schemaHelper.GetFieldFromNameDefaultJSON(fieldData.FieldName) + if err != nil { return fmt.Errorf("fieldName %v not exist in collection schema", fieldData.FieldName) } } @@ -1158,6 +1237,41 @@ func fillFieldPropertiesBySchema(columns []*schemapb.FieldData, schema *schemapb return nil } +// fillFieldPropertiesOnly fills field properties (FieldId, Type, ElementType) from schema. +// It assumes that columns have been validated and does not perform validation. +// Use validateFieldDataColumns before calling this function if validation is needed. +func fillFieldPropertiesOnly(columns []*schemapb.FieldData, schema *schemaInfo) error { + for _, fieldData := range columns { + // Use schemaHelper to get field schema, automatically handles dynamic fields + fieldSchema, err := schema.schemaHelper.GetFieldFromNameDefaultJSON(fieldData.FieldName) + if err != nil { + return fmt.Errorf("fieldName %v not exist in collection schema", fieldData.FieldName) + } + + fieldData.FieldId = fieldSchema.FieldID + fieldData.Type = fieldSchema.DataType + + // Set the ElementType because it may not be set in the insert request. 
+ if fieldData.Type == schemapb.DataType_Array { + fd, ok := fieldData.Field.(*schemapb.FieldData_Scalars) + if !ok || fd.Scalars.GetArrayData() == nil { + return fmt.Errorf("field convert FieldData_Scalars fail in fieldData, fieldName: %s, collectionName: %s", + fieldData.FieldName, schema.Name) + } + fd.Scalars.GetArrayData().ElementType = fieldSchema.ElementType + } else if fieldData.Type == schemapb.DataType_ArrayOfVector { + fd, ok := fieldData.Field.(*schemapb.FieldData_Vectors) + if !ok || fd.Vectors.GetVectorArray() == nil { + return fmt.Errorf("field convert FieldData_Vectors fail in fieldData, fieldName: %s, collectionName: %s", + fieldData.FieldName, schema.Name) + } + fd.Vectors.GetVectorArray().ElementType = fieldSchema.ElementType + } + } + + return nil +} + func ValidateUsername(username string) error { username = strings.TrimSpace(username) diff --git a/internal/proxy/util_test.go b/internal/proxy/util_test.go index 128225d644..1424bc3588 100644 --- a/internal/proxy/util_test.go +++ b/internal/proxy/util_test.go @@ -606,28 +606,64 @@ func TestValidateMultipleVectorFields(t *testing.T) { } func TestFillFieldIDBySchema(t *testing.T) { - schema := &schemapb.CollectionSchema{} - columns := []*schemapb.FieldData{ - { - FieldName: "TestFillFieldIDBySchema", - }, - } - - // length mismatch - assert.Error(t, fillFieldPropertiesBySchema(columns, schema)) - schema = &schemapb.CollectionSchema{ - Fields: []*schemapb.FieldSchema{ + t.Run("column count mismatch", func(t *testing.T) { + collSchema := &schemapb.CollectionSchema{} + schema := newSchemaInfo(collSchema) + columns := []*schemapb.FieldData{ { - Name: "TestFillFieldIDBySchema", - DataType: schemapb.DataType_Int64, - FieldID: 1, + FieldName: "TestFillFieldIDBySchema", }, - }, - } - assert.NoError(t, fillFieldPropertiesBySchema(columns, schema)) - assert.Equal(t, "TestFillFieldIDBySchema", columns[0].FieldName) - assert.Equal(t, schemapb.DataType_Int64, columns[0].Type) - assert.Equal(t, int64(1), columns[0].FieldId) + } + // Validation should fail due to column count mismatch + assert.Error(t, validateFieldDataColumns(columns, schema)) + }) + + t.Run("successful validation and fill", func(t *testing.T) { + collSchema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "TestFillFieldIDBySchema", + DataType: schemapb.DataType_Int64, + FieldID: 1, + }, + }, + } + schema := newSchemaInfo(collSchema) + columns := []*schemapb.FieldData{ + { + FieldName: "TestFillFieldIDBySchema", + }, + } + // Validation should succeed + assert.NoError(t, validateFieldDataColumns(columns, schema)) + // Fill properties should succeed + assert.NoError(t, fillFieldPropertiesOnly(columns, schema)) + assert.Equal(t, "TestFillFieldIDBySchema", columns[0].FieldName) + assert.Equal(t, schemapb.DataType_Int64, columns[0].Type) + assert.Equal(t, int64(1), columns[0].FieldId) + }) + + t.Run("field not in schema", func(t *testing.T) { + collSchema := &schemapb.CollectionSchema{ + Fields: []*schemapb.FieldSchema{ + { + Name: "FieldA", + DataType: schemapb.DataType_Int64, + FieldID: 1, + }, + }, + } + schema := newSchemaInfo(collSchema) + columns := []*schemapb.FieldData{ + { + FieldName: "FieldB", + }, + } + // Validation should fail because FieldB is not in schema + err := validateFieldDataColumns(columns, schema) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not exist in collection schema") + }) } func TestValidateUsername(t *testing.T) { diff --git a/tests/python_client/milvus_client/test_milvus_client_upsert.py 
b/tests/python_client/milvus_client/test_milvus_client_upsert.py index 8c896740f7..a670f838c4 100644 --- a/tests/python_client/milvus_client/test_milvus_client_upsert.py +++ b/tests/python_client/milvus_client/test_milvus_client_upsert.py @@ -550,4 +550,343 @@ class TestMilvusClientUpsertValid(TestMilvusClientV2Base): self.release_partitions(client, collection_name, partition_name) self.drop_partition(client, collection_name, partition_name) if self.has_collection(client, collection_name)[0]: - self.drop_collection(client, collection_name) \ No newline at end of file + self.drop_collection(client, collection_name) + + +class TestMilvusClientUpsertDedup(TestMilvusClientV2Base): + """Test case for upsert deduplication functionality""" + + @pytest.fixture(scope="function", params=["COSINE", "L2"]) + def metric_type(self, request): + yield request.param + + @pytest.mark.tags(CaseLabel.L1) + def test_milvus_client_upsert_dedup_int64_pk(self): + """ + target: test upsert with duplicate int64 primary keys in same batch + method: + 1. create collection with int64 primary key + 2. upsert data with duplicate primary keys [1, 2, 3, 2, 1] + 3. query to verify only last occurrence is kept + expected: only 3 unique records exist, with data from last occurrence + """ + client = self._client() + collection_name = cf.gen_collection_name_by_testcase_name() + + # 1. create collection + self.create_collection(client, collection_name, default_dim, consistency_level="Strong") + + # 2. upsert data with duplicate PKs: [1, 2, 3, 2, 1] + # Expected: keep last occurrence -> [3, 2, 1] at indices [2, 3, 4] + rng = np.random.default_rng(seed=19530) + rows = [ + {default_primary_key_field_name: 1, default_vector_field_name: list(rng.random((1, default_dim))[0]), + default_float_field_name: 1.0, default_string_field_name: "str_1_first"}, + {default_primary_key_field_name: 2, default_vector_field_name: list(rng.random((1, default_dim))[0]), + default_float_field_name: 2.0, default_string_field_name: "str_2_first"}, + {default_primary_key_field_name: 3, default_vector_field_name: list(rng.random((1, default_dim))[0]), + default_float_field_name: 3.0, default_string_field_name: "str_3"}, + {default_primary_key_field_name: 2, default_vector_field_name: list(rng.random((1, default_dim))[0]), + default_float_field_name: 2.1, default_string_field_name: "str_2_last"}, + {default_primary_key_field_name: 1, default_vector_field_name: list(rng.random((1, default_dim))[0]), + default_float_field_name: 1.1, default_string_field_name: "str_1_last"}, + ] + + results = self.upsert(client, collection_name, rows)[0] + # After deduplication, should only have 3 records + assert results['upsert_count'] == 3 + + # 3. 
query to verify deduplication - should have only 3 unique records + query_results = self.query(client, collection_name, filter="id >= 0")[0] + assert len(query_results) == 3 + + # Verify that last occurrence data is kept + id_to_data = {item['id']: item for item in query_results} + assert 1 in id_to_data + assert 2 in id_to_data + assert 3 in id_to_data + + # Check that data from last occurrence is preserved + assert id_to_data[1]['float'] == 1.1 + assert id_to_data[1]['varchar'] == "str_1_last" + assert id_to_data[2]['float'] == 2.1 + assert id_to_data[2]['varchar'] == "str_2_last" + assert id_to_data[3]['float'] == 3.0 + assert id_to_data[3]['varchar'] == "str_3" + + self.drop_collection(client, collection_name) + + @pytest.mark.tags(CaseLabel.L1) + def test_milvus_client_upsert_dedup_varchar_pk(self): + """ + target: test upsert with duplicate varchar primary keys in same batch + method: + 1. create collection with varchar primary key + 2. upsert data with duplicate primary keys ["a", "b", "c", "b", "a"] + 3. query to verify only last occurrence is kept + expected: only 3 unique records exist, with data from last occurrence + """ + client = self._client() + collection_name = cf.gen_collection_name_by_testcase_name() + + # 1. create collection with varchar primary key + schema = self.create_schema(client, enable_dynamic_field=True)[0] + schema.add_field("id", DataType.VARCHAR, max_length=64, is_primary=True, auto_id=False) + schema.add_field(default_vector_field_name, DataType.FLOAT_VECTOR, dim=default_dim) + schema.add_field("age", DataType.INT64) + index_params = self.prepare_index_params(client)[0] + index_params.add_index(default_vector_field_name, metric_type="COSINE") + self.create_collection(client, collection_name, default_dim, schema=schema, + index_params=index_params, consistency_level="Strong") + + # 2. upsert data with duplicate PKs: ["a", "b", "c", "b", "a"] + # Expected: keep last occurrence -> ["c", "b", "a"] at indices [2, 3, 4] + rng = np.random.default_rng(seed=19530) + rows = [ + {"id": "a", default_vector_field_name: list(rng.random((1, default_dim))[0]), + "age": 10}, + {"id": "b", default_vector_field_name: list(rng.random((1, default_dim))[0]), + "age": 20}, + {"id": "c", default_vector_field_name: list(rng.random((1, default_dim))[0]), + "age": 30}, + {"id": "b", default_vector_field_name: list(rng.random((1, default_dim))[0]), + "age": 21}, + {"id": "a", default_vector_field_name: list(rng.random((1, default_dim))[0]), + "age": 11}, + ] + + results = self.upsert(client, collection_name, rows)[0] + # After deduplication, should only have 3 records + assert results['upsert_count'] == 3 + + # 3. query to verify deduplication + query_results = self.query(client, collection_name, filter='id in ["a", "b", "c"]')[0] + assert len(query_results) == 3 + + # Verify that last occurrence data is kept + id_to_data = {item['id']: item for item in query_results} + assert "a" in id_to_data + assert "b" in id_to_data + assert "c" in id_to_data + + # Check that data from last occurrence is preserved + assert id_to_data["a"]["age"] == 11 + assert id_to_data["b"]["age"] == 21 + assert id_to_data["c"]["age"] == 30 + + self.drop_collection(client, collection_name) + + @pytest.mark.tags(CaseLabel.L1) + def test_milvus_client_upsert_dedup_all_duplicates(self): + """ + target: test upsert when all records have same primary key + method: + 1. create collection + 2. upsert 5 records with same primary key + 3. 
query to verify only 1 record exists + expected: only 1 record exists with data from last occurrence + """ + client = self._client() + collection_name = cf.gen_collection_name_by_testcase_name() + + # 1. create collection + self.create_collection(client, collection_name, default_dim, consistency_level="Strong") + + # 2. upsert data where all have same PK (id=1) + rng = np.random.default_rng(seed=19530) + rows = [ + {default_primary_key_field_name: 1, default_vector_field_name: list(rng.random((1, default_dim))[0]), + default_float_field_name: i * 1.0, default_string_field_name: f"version_{i}"} + for i in range(5) + ] + + results = self.upsert(client, collection_name, rows)[0] + # After deduplication, should only have 1 record + assert results['upsert_count'] == 1 + + # 3. query to verify only 1 record exists + query_results = self.query(client, collection_name, filter="id == 1")[0] + assert len(query_results) == 1 + + # Verify it's the last occurrence (i=4) + assert query_results[0]['float'] == 4.0 + assert query_results[0]['varchar'] == "version_4" + + self.drop_collection(client, collection_name) + + @pytest.mark.tags(CaseLabel.L1) + def test_milvus_client_upsert_dedup_no_duplicates(self): + """ + target: test upsert with no duplicate primary keys + method: + 1. create collection + 2. upsert data with unique primary keys + 3. query to verify all records exist + expected: all records exist as-is + """ + client = self._client() + collection_name = cf.gen_collection_name_by_testcase_name() + + # 1. create collection + self.create_collection(client, collection_name, default_dim, consistency_level="Strong") + + # 2. upsert data with unique PKs + rng = np.random.default_rng(seed=19530) + nb = 10 + rows = [ + {default_primary_key_field_name: i, default_vector_field_name: list(rng.random((1, default_dim))[0]), + default_float_field_name: i * 1.0, default_string_field_name: str(i)} + for i in range(nb) + ] + + results = self.upsert(client, collection_name, rows)[0] + # No deduplication should occur + assert results['upsert_count'] == nb + + # 3. query to verify all records exist + query_results = self.query(client, collection_name, filter=f"id >= 0")[0] + assert len(query_results) == nb + + self.drop_collection(client, collection_name) + + @pytest.mark.tags(CaseLabel.L2) + def test_milvus_client_upsert_dedup_large_batch(self): + """ + target: test upsert deduplication with large batch + method: + 1. create collection + 2. upsert large batch with 50% duplicate primary keys + 3. query to verify correct number of records + expected: only unique records exist + """ + client = self._client() + collection_name = cf.gen_collection_name_by_testcase_name() + + # 1. create collection + self.create_collection(client, collection_name, default_dim, consistency_level="Strong") + + # 2. upsert large batch where each ID appears twice + rng = np.random.default_rng(seed=19530) + nb = 500 + unique_ids = nb // 2 # 250 unique IDs + + rows = [] + for i in range(nb): + pk = i % unique_ids # This creates duplicates: 0,1,2...249,0,1,2...249 + rows.append({ + default_primary_key_field_name: pk, + default_vector_field_name: list(rng.random((1, default_dim))[0]), + default_float_field_name: float(i), # Different value for each row + default_string_field_name: f"batch_{i}" + }) + + results = self.upsert(client, collection_name, rows)[0] + # After deduplication, should only have unique_ids records + assert results['upsert_count'] == unique_ids + + # 3. 
query to verify correct number of records + query_results = self.query(client, collection_name, filter=f"id >= 0", limit=1000)[0] + assert len(query_results) == unique_ids + + # Verify that last occurrence is kept (should have higher float values) + for item in query_results: + pk = item['id'] + # Last occurrence of pk is at index (pk + unique_ids) + expected_float = float(pk + unique_ids) + assert item['float'] == expected_float + assert item['varchar'] == f"batch_{pk + unique_ids}" + + self.drop_collection(client, collection_name) + + @pytest.mark.tags(CaseLabel.L1) + def test_milvus_client_upsert_dedup_with_partition(self): + """ + target: test upsert deduplication works correctly with partitions + method: + 1. create collection with partition + 2. upsert data with duplicates to specific partition + 3. query to verify deduplication in partition + expected: deduplication works within partition + """ + client = self._client() + collection_name = cf.gen_collection_name_by_testcase_name() + partition_name = cf.gen_unique_str("partition") + + # 1. create collection and partition + self.create_collection(client, collection_name, default_dim, consistency_level="Strong") + self.create_partition(client, collection_name, partition_name) + + # 2. upsert data with duplicates to partition + rng = np.random.default_rng(seed=19530) + rows = [ + {default_primary_key_field_name: 1, default_vector_field_name: list(rng.random((1, default_dim))[0]), + default_float_field_name: 1.0, default_string_field_name: "first"}, + {default_primary_key_field_name: 2, default_vector_field_name: list(rng.random((1, default_dim))[0]), + default_float_field_name: 2.0, default_string_field_name: "unique"}, + {default_primary_key_field_name: 1, default_vector_field_name: list(rng.random((1, default_dim))[0]), + default_float_field_name: 1.1, default_string_field_name: "last"}, + ] + + results = self.upsert(client, collection_name, rows, partition_name=partition_name)[0] + assert results['upsert_count'] == 2 + + # 3. query partition to verify deduplication + query_results = self.query(client, collection_name, filter="id >= 0", + partition_names=[partition_name])[0] + assert len(query_results) == 2 + + # Verify correct data + id_to_data = {item['id']: item for item in query_results} + assert id_to_data[1]['float'] == 1.1 + assert id_to_data[1]['varchar'] == "last" + assert id_to_data[2]['float'] == 2.0 + assert id_to_data[2]['varchar'] == "unique" + + self.drop_collection(client, collection_name) + + @pytest.mark.tags(CaseLabel.L1) + def test_milvus_client_upsert_dedup_with_vectors(self): + """ + target: test upsert deduplication preserves correct vector data + method: + 1. create collection + 2. upsert data with duplicate PKs but different vectors + 3. search to verify correct vector is preserved + expected: vector from last occurrence is preserved + """ + client = self._client() + collection_name = cf.gen_collection_name_by_testcase_name() + + # 1. create collection + self.create_collection(client, collection_name, default_dim, consistency_level="Strong") + + # 2. 
upsert data with duplicate PK=1 but different vectors + # Create distinctly different vectors for easy verification + first_vector = [1.0] * default_dim # All 1.0 + last_vector = [2.0] * default_dim # All 2.0 + + rows = [ + {default_primary_key_field_name: 1, default_vector_field_name: first_vector, + default_float_field_name: 1.0, default_string_field_name: "first"}, + {default_primary_key_field_name: 2, default_vector_field_name: [0.5] * default_dim, + default_float_field_name: 2.0, default_string_field_name: "unique"}, + {default_primary_key_field_name: 1, default_vector_field_name: last_vector, + default_float_field_name: 1.1, default_string_field_name: "last"}, + ] + + results = self.upsert(client, collection_name, rows)[0] + assert results['upsert_count'] == 2 + + # 3. query to get vector data + query_results = self.query(client, collection_name, filter="id == 1", + output_fields=["id", "vector", "float", "varchar"])[0] + assert len(query_results) == 1 + + # Verify it's the last occurrence with last_vector + result = query_results[0] + assert result['float'] == 1.1 + assert result['varchar'] == "last" + # Vector should be last_vector (all 2.0) + assert all(abs(v - 2.0) < 0.001 for v in result['vector']) + + self.drop_collection(client, collection_name) \ No newline at end of file diff --git a/tests/python_client/testcases/test_insert.py b/tests/python_client/testcases/test_insert.py index f0627fc84b..93c635c3ed 100644 --- a/tests/python_client/testcases/test_insert.py +++ b/tests/python_client/testcases/test_insert.py @@ -2077,7 +2077,7 @@ class TestUpsertInvalid(TestcaseBase): log.debug(f"dirty_i: {dirty_i}") for i in range(len(data)): if data[i][dirty_i].__class__ is int: - tmp = data[i][0] + tmp = data[i][dirty_i] data[i][dirty_i] = "iamstring" error = {ct.err_code: 999, ct.err_msg: "The Input data type is inconsistent with defined schema"} collection_w.upsert(data=data, check_task=CheckTasks.err_res, check_items=error)
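
Editor's note: the following standalone sketch is not part of the patch. It illustrates the last-occurrence-wins deduplication strategy the commit message describes, on a toy in-memory row set. The row struct and the main driver are invented for illustration only; the index-keeping loop mirrors the generic findLastOccurrenceIndices added in internal/proxy/util.go, and the expected output matches TestUpsertTask_Deduplicate_Int64PK above (PKs [1, 2, 3, 2, 1] reduce to [3, 2, 1]).

    // Illustrative sketch (not part of the patch): last-occurrence-wins
    // deduplication over a toy row set. In the real code the columns live in
    // separate schemapb.FieldData slices, not in a struct.
    package main

    import "fmt"

    // row stands in for one upsert row (primary key plus one scalar column).
    type row struct {
        pk    int64
        value float32
    }

    // lastOccurrenceIndices returns, in first-seen order of the surviving rows,
    // the index of the last occurrence of each primary key. Generic over
    // comparable keys, like the patch's findLastOccurrenceIndices.
    func lastOccurrenceIndices[T comparable](pks []T) []int {
        last := make(map[T]int, len(pks))
        for i, pk := range pks {
            last[pk] = i
        }
        keep := make([]int, 0, len(last))
        for i, pk := range pks {
            if last[pk] == i {
                keep = append(keep, i)
            }
        }
        return keep
    }

    func main() {
        rows := []row{
            {1, 1.1}, {2, 2.2}, {3, 3.3}, {2, 2.4}, {1, 1.5},
        }
        pks := make([]int64, len(rows))
        for i, r := range rows {
            pks[i] = r.pk
        }
        // Keep indices 2, 3, 4 -> PKs 3, 2, 1 with values 3.3, 2.4, 1.5.
        deduped := make([]row, 0, len(rows))
        for _, idx := range lastOccurrenceIndices(pks) {
            deduped = append(deduped, rows[idx])
        }
        fmt.Println(deduped) // [{3 3.3} {2 2.4} {1 1.5}]
    }

In the patch itself this per-index selection is applied column-wise: DeduplicateFieldData builds the keep-index list from the primary-key column and then rebuilds every field (scalar and vector alike) with typeutil.PrepareResultFieldData plus typeutil.AppendFieldData, so all columns stay row-aligned after deduplication.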