From d3c95eaa77b9f7323891e7638d717dc7fbee5435 Mon Sep 17 00:00:00 2001 From: wei liu Date: Tue, 19 Aug 2025 11:15:45 +0800 Subject: [PATCH] enhance: Support partial field updates with upsert API (#42877) issue: #29735 Implement partial field update functionality for upsert operations, supporting scalar, vector, and dynamic JSON fields without requiring all collection fields. Changes: - Add queryPreExecute to retrieve existing records before upsert - Implement UpdateFieldData function for merging data - Add IDsChecker utility for efficient primary key lookups - Fix JSON data creation in tests using proper map marshaling - Add test cases for partial updates of different field types Signed-off-by: Wei Liu --- client/go.sum | 2 - client/milvusclient/write_options.go | 8 + go.mod | 11 +- .../proxy/httpserver/handler_v1.go | 10 +- .../proxy/httpserver/handler_v2.go | 9 +- .../distributed/proxy/httpserver/request.go | 2 + .../proxy/httpserver/request_v2.go | 1 + .../distributed/proxy/httpserver/utils.go | 36 +- .../proxy/httpserver/utils_test.go | 175 ++- internal/mocks/mock_proxy.go | 3 +- internal/proxy/impl.go | 10 +- internal/proxy/task_upsert.go | 417 ++++- internal/proxy/task_upsert_test.go | 551 ++++++- internal/proxy/util.go | 40 + pkg/go.mod | 1 - pkg/util/typeutil/ids_checker.go | 212 +++ pkg/util/typeutil/ids_checker_test.go | 372 +++++ pkg/util/typeutil/schema.go | 225 ++- pkg/util/typeutil/schema_test.go | 441 ++++++ tests/go_client/common/response_checker.go | 16 + tests/go_client/go.sum | 2 - .../go_client/testcases/helper/data_helper.go | 8 + tests/go_client/testcases/update_test.go | 206 +++ tests/go_client/testcases/upsert_test.go | 89 ++ tests/integration/util_collection.go | 4 + .../testcases/test_partial_update.py | 1351 +++++++++++++++++ 26 files changed, 4112 insertions(+), 90 deletions(-) create mode 100644 pkg/util/typeutil/ids_checker.go create mode 100644 pkg/util/typeutil/ids_checker_test.go create mode 100644 tests/go_client/testcases/update_test.go create mode 100644 tests/restful_client_v2/testcases/test_partial_update.py diff --git a/client/go.sum b/client/go.sum index 4609a4ed92..431c5ab1b8 100644 --- a/client/go.sum +++ b/client/go.sum @@ -318,8 +318,6 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfr github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8= github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= -github.com/milvus-io/milvus-proto/go-api/v2 v2.6.1-0.20250807040333-531631e7fce6 h1:qTBOTsZ3OwEXkrHRqPn562ddkDqeToIY6CstLIaVQYs= -github.com/milvus-io/milvus-proto/go-api/v2 v2.6.1-0.20250807040333-531631e7fce6/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs= github.com/milvus-io/milvus-proto/go-api/v2 v2.6.1-0.20250807065533-ebdc11f5df17 h1:zyrKuc0rwT5xWIFkZr/bFWXXYbYvSBMT3iFITnaR8IE= github.com/milvus-io/milvus-proto/go-api/v2 v2.6.1-0.20250807065533-ebdc11f5df17/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs= github.com/milvus-io/milvus/pkg/v2 v2.0.0-20250319085209-5a6b4e56d59e h1:VCr43pG4efacDbM4au70fh8/5hNTftoWzm1iEumvDWM= diff --git a/client/milvusclient/write_options.go b/client/milvusclient/write_options.go index 80615cdbbd..ac8ef07a88 100644 --- a/client/milvusclient/write_options.go +++ b/client/milvusclient/write_options.go @@ -52,6 +52,7 @@ type columnBasedDataOption struct { collName string partitionName string columns 
[]column.Column + partialUpdate bool } func (opt *columnBasedDataOption) WriteBackPKs(_ *entity.Schema, _ column.Column) error { @@ -253,6 +254,11 @@ func (opt *columnBasedDataOption) WithPartition(partitionName string) *columnBas return opt } +func (opt *columnBasedDataOption) WithPartialUpdate(partialUpdate bool) *columnBasedDataOption { + opt.partialUpdate = partialUpdate + return opt +} + func (opt *columnBasedDataOption) CollectionName() string { return opt.collName } @@ -282,6 +288,7 @@ func (opt *columnBasedDataOption) UpsertRequest(coll *entity.Collection) (*milvu FieldsData: fieldsData, NumRows: uint32(rowNum), SchemaTimestamp: coll.UpdateTimestamp, + PartialUpdate: opt.partialUpdate, }, nil } @@ -340,6 +347,7 @@ func (opt *rowBasedDataOption) UpsertRequest(coll *entity.Collection) (*milvuspb PartitionName: opt.partitionName, FieldsData: fieldsData, NumRows: uint32(rowNum), + PartialUpdate: opt.partialUpdate, }, nil } diff --git a/go.mod b/go.mod index 1a9b760f88..c815918d3e 100644 --- a/go.mod +++ b/go.mod @@ -86,6 +86,13 @@ require ( mosn.io/holmes v1.0.2 ) +require ( + github.com/gopherjs/gopherjs v1.12.80 // indirect + github.com/jtolds/gls v4.20.0+incompatible // indirect + github.com/smartystreets/assertions v1.2.0 // indirect + github.com/smartystreets/goconvey v1.7.2 // indirect +) + require ( cloud.google.com/go v0.115.0 // indirect cloud.google.com/go/auth v0.6.1 // indirect @@ -164,7 +171,6 @@ require ( github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.5 // indirect - github.com/gopherjs/gopherjs v1.12.80 // indirect github.com/gorilla/websocket v1.4.2 // indirect github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect @@ -177,7 +183,6 @@ require ( github.com/ianlancetaylor/cgosymbolizer v0.0.0-20221217025313-27d3c9f66b6a // indirect github.com/jonboulle/clockwork v0.2.2 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/jtolds/gls v4.20.0+incompatible // indirect github.com/klauspost/asmfmt v1.3.2 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect github.com/kr/pretty v0.3.1 // indirect @@ -220,8 +225,6 @@ require ( github.com/shirou/gopsutil/v3 v3.23.7 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/sirupsen/logrus v1.9.3 // indirect - github.com/smartystreets/assertions v1.2.0 // indirect - github.com/smartystreets/goconvey v1.7.2 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/spf13/afero v1.6.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect diff --git a/internal/distributed/proxy/httpserver/handler_v1.go b/internal/distributed/proxy/httpserver/handler_v1.go index 16b110c13b..d830e66c2f 100644 --- a/internal/distributed/proxy/httpserver/handler_v1.go +++ b/internal/distributed/proxy/httpserver/handler_v1.go @@ -755,7 +755,7 @@ func (h *HandlersV1) insert(c *gin.Context) { return nil, RestRequestInterceptorErr } body, _ := c.Get(gin.BodyBytesKey) - err, httpReq.Data, _ = checkAndSetData(body.([]byte), collSchema) + err, httpReq.Data, _ = checkAndSetData(body.([]byte), collSchema, false) if err != nil { log.Warn("high level restful api, fail to deal with insert data", zap.Any("body", body), zap.Error(err)) HTTPAbortReturn(c, http.StatusOK, gin.H{ @@ -765,7 +765,7 @@ func (h *HandlersV1) insert(c *gin.Context) { return nil, RestRequestInterceptorErr } insertReq := req.(*milvuspb.InsertRequest) - 
insertReq.FieldsData, err = anyToColumns(httpReq.Data, nil, collSchema, true) + insertReq.FieldsData, err = anyToColumns(httpReq.Data, nil, collSchema, true, false) if err != nil { log.Warn("high level restful api, fail to deal with insert data", zap.Any("data", httpReq.Data), zap.Error(err)) HTTPAbortReturn(c, http.StatusOK, gin.H{ @@ -831,6 +831,7 @@ func (h *HandlersV1) upsert(c *gin.Context) { httpReq.DbName = singleUpsertReq.DbName httpReq.CollectionName = singleUpsertReq.CollectionName httpReq.Data = []map[string]interface{}{singleUpsertReq.Data} + httpReq.PartialUpdate = singleUpsertReq.PartialUpdate } if httpReq.CollectionName == "" || httpReq.Data == nil { log.Warn("high level restful api, upsert require parameter: [collectionName, data], but miss") @@ -844,6 +845,7 @@ func (h *HandlersV1) upsert(c *gin.Context) { DbName: httpReq.DbName, CollectionName: httpReq.CollectionName, NumRows: uint32(len(httpReq.Data)), + PartialUpdate: httpReq.PartialUpdate, } c.Set(ContextRequest, req) username, _ := c.Get(ContextUsername) @@ -861,7 +863,7 @@ func (h *HandlersV1) upsert(c *gin.Context) { } } body, _ := c.Get(gin.BodyBytesKey) - err, httpReq.Data, _ = checkAndSetData(body.([]byte), collSchema) + err, httpReq.Data, _ = checkAndSetData(body.([]byte), collSchema, httpReq.PartialUpdate) if err != nil { log.Warn("high level restful api, fail to deal with upsert data", zap.Any("body", body), zap.Error(err)) HTTPAbortReturn(c, http.StatusOK, gin.H{ @@ -871,7 +873,7 @@ func (h *HandlersV1) upsert(c *gin.Context) { return nil, RestRequestInterceptorErr } upsertReq := req.(*milvuspb.UpsertRequest) - upsertReq.FieldsData, err = anyToColumns(httpReq.Data, nil, collSchema, false) + upsertReq.FieldsData, err = anyToColumns(httpReq.Data, nil, collSchema, false, httpReq.PartialUpdate) if err != nil { log.Warn("high level restful api, fail to deal with upsert data", zap.Any("data", httpReq.Data), zap.Error(err)) HTTPAbortReturn(c, http.StatusOK, gin.H{ diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 270e624bdb..fd3dd26978 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -981,7 +981,7 @@ func (h *HandlersV2) insert(ctx context.Context, c *gin.Context, anyReq any, dbN } body, _ := c.Get(gin.BodyBytesKey) var validDataMap map[string][]bool - err, httpReq.Data, validDataMap = checkAndSetData(body.([]byte), collSchema) + err, httpReq.Data, validDataMap = checkAndSetData(body.([]byte), collSchema, false) if err != nil { log.Ctx(ctx).Warn("high level restful api, fail to deal with insert data", zap.Error(err), zap.String("body", string(body.([]byte)))) HTTPAbortReturn(c, http.StatusOK, gin.H{ @@ -992,7 +992,7 @@ func (h *HandlersV2) insert(ctx context.Context, c *gin.Context, anyReq any, dbN } req.NumRows = uint32(len(httpReq.Data)) - req.FieldsData, err = anyToColumns(httpReq.Data, validDataMap, collSchema, true) + req.FieldsData, err = anyToColumns(httpReq.Data, validDataMap, collSchema, true, false) if err != nil { log.Ctx(ctx).Warn("high level restful api, fail to deal with insert data", zap.Any("data", httpReq.Data), zap.Error(err)) HTTPAbortReturn(c, http.StatusOK, gin.H{ @@ -1045,6 +1045,7 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN DbName: dbName, CollectionName: httpReq.CollectionName, PartitionName: httpReq.PartitionName, + PartialUpdate: httpReq.PartialUpdate, // PartitionName: "_default", } 
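For reference, a minimal sketch of what a partial-update call against the v2 RESTful upsert endpoint could look like. This is illustrative only: the endpoint address assumes a default local deployment, the collection and field names ("books", "book_id", "word_count") are placeholders, and the JSON keys mirror the request struct in this patch (collectionName, data, partialUpdate). Each row carries only the primary key plus the fields being changed.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// With partialUpdate=true the proxy merges the omitted fields from the
	// currently stored row instead of requiring a full row per entry.
	body := map[string]any{
		"collectionName": "books",
		"partialUpdate":  true,
		"data": []map[string]any{
			{"book_id": 1, "word_count": 42}, // other fields keep their stored values
		},
	}
	buf, _ := json.Marshal(body)
	resp, err := http.Post("http://localhost:19530/v2/vectordb/entities/upsert",
		"application/json", bytes.NewReader(buf))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	fmt.Println(resp.Status)
}
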
c.Set(ContextRequest, req) @@ -1055,7 +1056,7 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN } body, _ := c.Get(gin.BodyBytesKey) var validDataMap map[string][]bool - err, httpReq.Data, validDataMap = checkAndSetData(body.([]byte), collSchema) + err, httpReq.Data, validDataMap = checkAndSetData(body.([]byte), collSchema, httpReq.PartialUpdate) if err != nil { log.Ctx(ctx).Warn("high level restful api, fail to deal with upsert data", zap.Any("body", body), zap.Error(err)) HTTPAbortReturn(c, http.StatusOK, gin.H{ @@ -1066,7 +1067,7 @@ func (h *HandlersV2) upsert(ctx context.Context, c *gin.Context, anyReq any, dbN } req.NumRows = uint32(len(httpReq.Data)) - req.FieldsData, err = anyToColumns(httpReq.Data, validDataMap, collSchema, false) + req.FieldsData, err = anyToColumns(httpReq.Data, validDataMap, collSchema, false, httpReq.PartialUpdate) if err != nil { log.Ctx(ctx).Warn("high level restful api, fail to deal with upsert data", zap.Any("data", httpReq.Data), zap.Error(err)) HTTPAbortReturn(c, http.StatusOK, gin.H{ diff --git a/internal/distributed/proxy/httpserver/request.go b/internal/distributed/proxy/httpserver/request.go index db18622ac5..cc14b7b1e6 100644 --- a/internal/distributed/proxy/httpserver/request.go +++ b/internal/distributed/proxy/httpserver/request.go @@ -71,12 +71,14 @@ type UpsertReq struct { DbName string `json:"dbName"` CollectionName string `json:"collectionName" validate:"required"` Data []map[string]interface{} `json:"data" validate:"required"` + PartialUpdate bool `json:"partialUpdate"` } type SingleUpsertReq struct { DbName string `json:"dbName"` CollectionName string `json:"collectionName" validate:"required"` Data map[string]interface{} `json:"data" validate:"required"` + PartialUpdate bool `json:"partialUpdate"` } type SearchReq struct { diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index 889ad81f44..2a0337c625 100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -272,6 +272,7 @@ type CollectionDataReq struct { CollectionName string `json:"collectionName" binding:"required"` PartitionName string `json:"partitionName"` Data []map[string]interface{} `json:"data" binding:"required"` + PartialUpdate bool `json:"partialUpdate"` } func (req *CollectionDataReq) GetDbName() string { return req.DbName } diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index 92b68e20a8..f7e87f6934 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -284,7 +284,7 @@ func printIndexes(indexes []*milvuspb.IndexDescription) []gin.H { // --------------------- insert param --------------------- // -func checkAndSetData(body []byte, collSchema *schemapb.CollectionSchema) (error, []map[string]interface{}, map[string][]bool) { +func checkAndSetData(body []byte, collSchema *schemapb.CollectionSchema, partialUpdate bool) (error, []map[string]interface{}, map[string][]bool) { var reallyDataArray []map[string]interface{} validDataMap := make(map[string][]bool) dataResult := gjson.GetBytes(body, HTTPRequestData) @@ -321,7 +321,14 @@ func checkAndSetData(body []byte, collSchema *schemapb.CollectionSchema) (error, } } - dataString := gjson.Get(data.Raw, fieldName).String() + // For partial update, check if field exists in the data + fieldValue := gjson.Get(data.Raw, fieldName) + if partialUpdate && 
!fieldValue.Exists() { + // Skip fields that are not provided in partial update + continue + } + + dataString := fieldValue.String() // if has pass pk than just to try to set it if field.IsPrimaryKey && field.AutoID && len(dataString) == 0 { continue @@ -732,7 +739,7 @@ func convertToIntArray(dataType schemapb.DataType, arr interface{}) []int32 { return res } -func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool, sch *schemapb.CollectionSchema, inInsert bool) ([]*schemapb.FieldData, error) { +func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool, sch *schemapb.CollectionSchema, inInsert bool, partialUpdate bool) ([]*schemapb.FieldData, error) { rowsLen := len(rows) if rowsLen == 0 { return []*schemapb.FieldData{}, errors.New("no row need to be convert to columns") @@ -810,11 +817,12 @@ func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool, IsDynamic: field.IsDynamic, } } - if len(nameDims) == 0 && len(sch.Functions) == 0 { + if len(nameDims) == 0 && len(sch.Functions) == 0 && !partialUpdate { return nil, fmt.Errorf("collection: %s has no vector field or functions", sch.Name) } dynamicCol := make([][]byte, 0, rowsLen) + fieldLen := make(map[string]int) for _, row := range rows { // collection schema name need not be same, since receiver could have other names @@ -844,8 +852,12 @@ func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool, continue } if !ok { + if partialUpdate { + continue + } return nil, fmt.Errorf("row %d does not has field %s", idx, field.Name) } + fieldLen[field.Name] += 1 switch field.DataType { case schemapb.DataType_Bool: nameColumns[field.Name] = append(nameColumns[field.Name].([]bool), candi.v.Interface().(bool)) @@ -923,6 +935,22 @@ func anyToColumns(rows []map[string]interface{}, validDataMap map[string][]bool, } columns := make([]*schemapb.FieldData, 0, len(nameColumns)) for name, column := range nameColumns { + if fieldLen[name] == 0 && partialUpdate { + // for partial update, skip update for nullable field + // cause we cannot distinguish between missing fields and fields explicitly set to null + log.Info("skip empty field for partial update", + zap.String("fieldName", name)) + continue + } + if fieldLen[name] != rowsLen && partialUpdate { + // for partial update, if try to update different field in different rows, return error + log.Info("field len is not equal to rows len", + zap.String("fieldName", name), + zap.Int("fieldLen", fieldLen[name]), + zap.Int("rowsLen", rowsLen)) + return nil, fmt.Errorf("column %s has length %d, expected %d", name, fieldLen[name], rowsLen) + } + colData := fieldData[name] switch colData.Type { case schemapb.DataType_Bool: diff --git a/internal/distributed/proxy/httpserver/utils_test.go b/internal/distributed/proxy/httpserver/utils_test.go index 7ebd61887c..cf73ef46ed 100644 --- a/internal/distributed/proxy/httpserver/utils_test.go +++ b/internal/distributed/proxy/httpserver/utils_test.go @@ -604,12 +604,12 @@ func TestAnyToColumns(t *testing.T) { req := InsertReq{} coll := generateCollectionSchema(schemapb.DataType_Int64, false, true) var err error - err, req.Data, _ = checkAndSetData(body, coll) + err, req.Data, _ = checkAndSetData(body, coll, false) assert.Equal(t, nil, err) assert.Equal(t, int64(0), req.Data[0]["id"]) assert.Equal(t, int64(1), req.Data[0]["book_id"]) assert.Equal(t, int64(2), req.Data[0]["word_count"]) - fieldsData, err := anyToColumns(req.Data, nil, coll, true) + fieldsData, err := anyToColumns(req.Data, nil, 
coll, true, false) assert.Equal(t, nil, err) assert.Equal(t, true, fieldsData[len(fieldsData)-1].IsDynamic) assert.Equal(t, schemapb.DataType_JSON, fieldsData[len(fieldsData)-1].Type) @@ -621,12 +621,12 @@ func TestAnyToColumns(t *testing.T) { req := InsertReq{} coll := generateCollectionSchema(schemapb.DataType_Int64, false, true) var err error - err, req.Data, _ = checkAndSetData(body, coll) + err, req.Data, _ = checkAndSetData(body, coll, false) assert.Equal(t, nil, err) assert.Equal(t, int64(0), req.Data[0]["id"]) assert.Equal(t, int64(1), req.Data[0]["book_id"]) assert.Equal(t, int64(2), req.Data[0]["word_count"]) - fieldsData, err := anyToColumns(req.Data, nil, coll, false) + fieldsData, err := anyToColumns(req.Data, nil, coll, false, false) assert.Equal(t, nil, err) assert.Equal(t, true, fieldsData[len(fieldsData)-1].IsDynamic) assert.Equal(t, schemapb.DataType_JSON, fieldsData[len(fieldsData)-1].Type) @@ -638,12 +638,12 @@ func TestAnyToColumns(t *testing.T) { req := InsertReq{} coll := generateCollectionSchema(schemapb.DataType_Int64, true, true) var err error - err, req.Data, _ = checkAndSetData(body, coll) + err, req.Data, _ = checkAndSetData(body, coll, false) assert.Equal(t, nil, err) assert.Equal(t, int64(0), req.Data[0]["id"]) assert.Equal(t, int64(1), req.Data[0]["book_id"]) assert.Equal(t, int64(2), req.Data[0]["word_count"]) - _, err = anyToColumns(req.Data, nil, coll, true) + _, err = anyToColumns(req.Data, nil, coll, true, false) assert.Error(t, err) assert.Equal(t, true, strings.HasPrefix(err.Error(), "no need to pass pk field")) }) @@ -652,7 +652,7 @@ func TestAnyToColumns(t *testing.T) { body := []byte("{\"data\": {\"id\": 0, \"book_id\": 1, \"book_intro\": [0.1, 0.2], \"word_count\": 2, \"classified\": false, \"databaseID\": null}}") coll := generateCollectionSchema(schemapb.DataType_Int64, true, false) var err error - err, _, _ = checkAndSetData(body, coll) + err, _, _ = checkAndSetData(body, coll, false) assert.Error(t, err) assert.Equal(t, true, strings.HasPrefix(err.Error(), "has pass more fiel")) }) @@ -662,12 +662,12 @@ func TestAnyToColumns(t *testing.T) { req := InsertReq{} coll := generateCollectionSchema(schemapb.DataType_Int64, false, false) var err error - err, req.Data, _ = checkAndSetData(body, coll) + err, req.Data, _ = checkAndSetData(body, coll, false) assert.Equal(t, nil, err) assert.Equal(t, int64(1), req.Data[0]["book_id"]) assert.Equal(t, []float32{0.1, 0.2}, req.Data[0]["book_intro"]) assert.Equal(t, int64(2), req.Data[0]["word_count"]) - fieldsData, err := anyToColumns(req.Data, nil, coll, true) + fieldsData, err := anyToColumns(req.Data, nil, coll, true, false) assert.Equal(t, nil, err) assert.Equal(t, 3, len(fieldsData)) assert.Equal(t, false, fieldsData[len(fieldsData)-1].IsDynamic) @@ -677,7 +677,7 @@ func TestAnyToColumns(t *testing.T) { body := []byte("{\"data\": { \"book_intro\": [0.1, 0.2], \"word_count\": 2}}") coll := generateCollectionSchema(schemapb.DataType_Int64, false, false) var err error - err, _, _ = checkAndSetData(body, coll) + err, _, _ = checkAndSetData(body, coll, false) assert.Error(t, err) assert.Equal(t, true, strings.HasPrefix(err.Error(), "strconv.ParseInt: parsing \"\": invalid syntax")) }) @@ -687,11 +687,11 @@ func TestAnyToColumns(t *testing.T) { req := InsertReq{} coll := generateCollectionSchema(schemapb.DataType_Int64, true, false) var err error - err, req.Data, _ = checkAndSetData(body, coll) + err, req.Data, _ = checkAndSetData(body, coll, false) assert.Equal(t, nil, err) assert.Equal(t, []float32{0.1, 
0.2}, req.Data[0]["book_intro"]) assert.Equal(t, int64(2), req.Data[0]["word_count"]) - fieldsData, err := anyToColumns(req.Data, nil, coll, true) + fieldsData, err := anyToColumns(req.Data, nil, coll, true, false) assert.Equal(t, nil, err) assert.Equal(t, 2, len(fieldsData)) assert.Equal(t, false, fieldsData[len(fieldsData)-1].IsDynamic) @@ -702,12 +702,12 @@ func TestAnyToColumns(t *testing.T) { req := InsertReq{} coll := generateCollectionSchema(schemapb.DataType_Int64, true, false) var err error - err, req.Data, _ = checkAndSetData(body, coll) + err, req.Data, _ = checkAndSetData(body, coll, false) assert.Equal(t, nil, err) assert.Equal(t, int64(1), req.Data[0]["book_id"]) assert.Equal(t, []float32{0.1, 0.2}, req.Data[0]["book_intro"]) assert.Equal(t, int64(2), req.Data[0]["word_count"]) - fieldsData, err := anyToColumns(req.Data, nil, coll, false) + fieldsData, err := anyToColumns(req.Data, nil, coll, false, false) assert.Equal(t, nil, err) assert.Equal(t, 3, len(fieldsData)) assert.Equal(t, false, fieldsData[len(fieldsData)-1].IsDynamic) @@ -718,16 +718,117 @@ func TestAnyToColumns(t *testing.T) { req := InsertReq{} coll := generateCollectionSchema(schemapb.DataType_Int64, true, false) var err error - err, req.Data, _ = checkAndSetData(body, coll) + err, req.Data, _ = checkAndSetData(body, coll, false) assert.Equal(t, nil, err) assert.Equal(t, int64(1), req.Data[0]["book_id"]) assert.Equal(t, []float32{0.1, 0.2}, req.Data[0]["book_intro"]) assert.Equal(t, int64(2), req.Data[0]["word_count"]) - fieldsData, err := anyToColumns(req.Data, nil, coll, false) + fieldsData, err := anyToColumns(req.Data, nil, coll, false, false) assert.Equal(t, nil, err) assert.Equal(t, 3, len(fieldsData)) assert.Equal(t, false, fieldsData[len(fieldsData)-1].IsDynamic) }) + + t.Run("partial update with inconsistent fields should fail", func(t *testing.T) { + // Create a simple schema with two fields: a and b + schema := &schemapb.CollectionSchema{ + Name: "test_collection", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "id", + DataType: schemapb.DataType_Int64, + IsPrimaryKey: true, + AutoID: false, + }, + { + FieldID: 101, + Name: "a", + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 102, + Name: "b", + DataType: schemapb.DataType_Int64, + }, + }, + EnableDynamicField: false, + } + + // Create two rows: first row updates only field 'a', second row updates only field 'b' + rows := []map[string]interface{}{ + { + "id": int64(1), + "a": int64(100), // Only field 'a' is provided + }, + { + "id": int64(2), + "b": int64(200), // Only field 'b' is provided + }, + } + + // Test with partial update = true, this should fail + // because different rows are updating different fields + _, err := anyToColumns(rows, nil, schema, false, true) + assert.Error(t, err) + assert.Contains(t, err.Error(), "has length 1, expected 2") + }) + + t.Run("partial update with consistent missing fields should succeed", func(t *testing.T) { + // Create a simple schema with two fields: a and b + schema := &schemapb.CollectionSchema{ + Name: "test_collection", + Fields: []*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "id", + DataType: schemapb.DataType_Int64, + IsPrimaryKey: true, + AutoID: false, + }, + { + FieldID: 101, + Name: "a", + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 102, + Name: "b", + DataType: schemapb.DataType_Int64, + Nullable: true, // Make field 'b' nullable + }, + }, + EnableDynamicField: false, + } + + // Create two rows: both rows update only field 'a', field 'b' is missing in 
both + rows := []map[string]interface{}{ + { + "id": int64(1), + "a": int64(100), // Only field 'a' is provided + }, + { + "id": int64(2), + "a": int64(200), // Only field 'a' is provided + }, + } + + // Test with partial update = true, this should succeed + // because the same fields are being updated in all rows + fieldsData, err := anyToColumns(rows, nil, schema, false, true) + assert.NoError(t, err) + assert.NotNil(t, fieldsData) + + // Should have id and a fields, but not b (since it's not provided and nullable) + fieldNames := make(map[string]bool) + for _, fd := range fieldsData { + fieldNames[fd.FieldName] = true + } + assert.True(t, fieldNames["id"]) + assert.True(t, fieldNames["a"]) + // Field 'b' should not be present since it wasn't provided in any row + assert.False(t, fieldNames["b"]) + }) } func TestCheckAndSetData(t *testing.T) { @@ -735,7 +836,7 @@ func TestCheckAndSetData(t *testing.T) { body := []byte("{\"data\": {\"id\": 0,\"$meta\": 2,\"book_id\": 1, \"book_intro\": [0.1, 0.2], \"word_count\": 2, \"classified\": false, \"databaseID\": null}}") coll := generateCollectionSchema(schemapb.DataType_Int64, false, true) var err error - err, _, _ = checkAndSetData(body, coll) + err, _, _ = checkAndSetData(body, coll, false) assert.Error(t, err) assert.Equal(t, true, strings.HasPrefix(err.Error(), "use the invalid field name")) }) @@ -759,7 +860,7 @@ func TestCheckAndSetData(t *testing.T) { primaryField, floatVectorField, }, EnableDynamicField: true, - }) + }, false) assert.Error(t, err) assert.Equal(t, true, strings.HasPrefix(err.Error(), "missing vector field")) err, _, _ = checkAndSetData(body, &schemapb.CollectionSchema{ @@ -768,7 +869,7 @@ func TestCheckAndSetData(t *testing.T) { primaryField, binaryVectorField, }, EnableDynamicField: true, - }) + }, false) assert.Error(t, err) assert.Equal(t, true, strings.HasPrefix(err.Error(), "missing vector field")) err, _, _ = checkAndSetData(body, &schemapb.CollectionSchema{ @@ -777,7 +878,7 @@ func TestCheckAndSetData(t *testing.T) { primaryField, float16VectorField, }, EnableDynamicField: true, - }) + }, false) assert.Error(t, err) assert.Equal(t, true, strings.HasPrefix(err.Error(), "missing vector field")) err, _, _ = checkAndSetData(body, &schemapb.CollectionSchema{ @@ -786,7 +887,7 @@ func TestCheckAndSetData(t *testing.T) { primaryField, bfloat16VectorField, }, EnableDynamicField: true, - }) + }, false) assert.Error(t, err) assert.Equal(t, true, strings.HasPrefix(err.Error(), "missing vector field")) err, _, _ = checkAndSetData(body, &schemapb.CollectionSchema{ @@ -795,7 +896,7 @@ func TestCheckAndSetData(t *testing.T) { primaryField, int8VectorField, }, EnableDynamicField: true, - }) + }, false) assert.Error(t, err) assert.Equal(t, true, strings.HasPrefix(err.Error(), "missing vector field")) }) @@ -809,7 +910,7 @@ func TestCheckAndSetData(t *testing.T) { DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64, }) - err, data, validData := checkAndSetData(body, coll) + err, data, validData := checkAndSetData(body, coll, false) assert.Equal(t, nil, err) assert.Equal(t, 1, len(data)) assert.Equal(t, 0, len(validData)) @@ -824,7 +925,7 @@ func TestCheckAndSetData(t *testing.T) { DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64, }) - err, data, validData := checkAndSetData(body, coll) + err, data, validData := checkAndSetData(body, coll, false) assert.Equal(t, nil, err) assert.Equal(t, 1, len(data)) assert.Equal(t, 0, len(validData)) @@ -839,7 +940,7 @@ func TestCheckAndSetData(t 
*testing.T) { DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64, }) - err, data, validData := checkAndSetData(body, coll) + err, data, validData := checkAndSetData(body, coll, false) assert.Equal(t, nil, err) assert.Equal(t, 1, len(data)) assert.Equal(t, 0, len(validData)) @@ -855,7 +956,7 @@ func TestInsertWithInt64(t *testing.T) { DataType: schemapb.DataType_Array, ElementType: schemapb.DataType_Int64, }) - err, data, validData := checkAndSetData(body, coll) + err, data, validData := checkAndSetData(body, coll, false) assert.Equal(t, nil, err) assert.Equal(t, 1, len(data)) assert.Equal(t, 0, len(validData)) @@ -878,7 +979,7 @@ func TestInsertWithNullableField(t *testing.T) { Nullable: true, }) body := []byte("{\"data\": [{\"book_id\": 9999999999999999, \"\nullable\": null,\"book_intro\": [0.1, 0.2], \"word_count\": 2, \"" + arrayFieldName + "\": [9999999999999999]},{\"book_id\": 1, \"nullable\": 1,\"book_intro\": [0.3, 0.4], \"word_count\": 2, \"" + arrayFieldName + "\": [9999999999999999]}]") - err, data, validData := checkAndSetData(body, coll) + err, data, validData := checkAndSetData(body, coll, false) assert.Equal(t, nil, err) assert.Equal(t, 2, len(data)) assert.Equal(t, 1, len(validData)) @@ -891,7 +992,7 @@ func TestInsertWithNullableField(t *testing.T) { assert.Equal(t, 4, len(data[0])) assert.Equal(t, 5, len(data[1])) - fieldData, err := anyToColumns(data, validData, coll, true) + fieldData, err := anyToColumns(data, validData, coll, true, false) assert.Equal(t, nil, err) assert.Equal(t, len(coll.Fields), len(fieldData)) } @@ -914,7 +1015,7 @@ func TestInsertWithDefaultValueField(t *testing.T) { }, }) body := []byte("{\"data\": [{\"book_id\": 9999999999999999, \"\fid\": null,\"book_intro\": [0.1, 0.2], \"word_count\": 2, \"" + arrayFieldName + "\": [9999999999999999]},{\"book_id\": 1, \"fid\": 1,\"book_intro\": [0.3, 0.4], \"word_count\": 2, \"" + arrayFieldName + "\": [9999999999999999]}]") - err, data, validData := checkAndSetData(body, coll) + err, data, validData := checkAndSetData(body, coll, false) assert.Equal(t, nil, err) assert.Equal(t, 2, len(data)) assert.Equal(t, 1, len(validData)) @@ -927,7 +1028,7 @@ func TestInsertWithDefaultValueField(t *testing.T) { assert.Equal(t, 4, len(data[0])) assert.Equal(t, 5, len(data[1])) - fieldData, err := anyToColumns(data, validData, coll, true) + fieldData, err := anyToColumns(data, validData, coll, true, false) assert.Equal(t, nil, err) assert.Equal(t, len(coll.Fields), len(fieldData)) } @@ -2078,21 +2179,21 @@ func newRowsWithArray(results []map[string]interface{}) []map[string]interface{} func TestArray(t *testing.T) { body, _ := generateRequestBody(schemapb.DataType_Int64) collectionSchema := generateCollectionSchema(schemapb.DataType_Int64, false, true) - err, rows, validRows := checkAndSetData(body, collectionSchema) + err, rows, validRows := checkAndSetData(body, collectionSchema, false) assert.Equal(t, nil, err) assert.Equal(t, 0, len(validRows)) assert.Equal(t, true, compareRows(rows, generateRawRows(schemapb.DataType_Int64), compareRow)) - data, err := anyToColumns(rows, validRows, collectionSchema, true) + data, err := anyToColumns(rows, validRows, collectionSchema, true, false) assert.Equal(t, nil, err) assert.Equal(t, len(collectionSchema.Fields), len(data)) body, _ = generateRequestBodyWithArray(schemapb.DataType_Int64) collectionSchema = newCollectionSchemaWithArray(generateCollectionSchema(schemapb.DataType_Int64, false, true)) - err, rows, validRows = checkAndSetData(body, collectionSchema) + 
err, rows, validRows = checkAndSetData(body, collectionSchema, false) assert.Equal(t, nil, err) assert.Equal(t, 0, len(validRows)) assert.Equal(t, true, compareRows(rows, newRowsWithArray(generateRawRows(schemapb.DataType_Int64)), compareRow)) - data, err = anyToColumns(rows, validRows, collectionSchema, true) + data, err = anyToColumns(rows, validRows, collectionSchema, true, false) assert.Equal(t, nil, err) assert.Equal(t, len(collectionSchema.Fields), len(data)) } @@ -2175,7 +2276,7 @@ func TestVector(t *testing.T) { }, EnableDynamicField: true, } - err, rows, validRows := checkAndSetData(body, collectionSchema) + err, rows, validRows := checkAndSetData(body, collectionSchema, false) assert.Equal(t, nil, err) for i, row := range rows { assert.Equal(t, 2, len(row[floatVector].([]float32))) @@ -2200,7 +2301,7 @@ func TestVector(t *testing.T) { assert.Equal(t, 16, len(row[sparseFloatVector].([]byte))) } assert.Equal(t, 0, len(validRows)) - data, err := anyToColumns(rows, validRows, collectionSchema, true) + data, err := anyToColumns(rows, validRows, collectionSchema, true, false) assert.Equal(t, nil, err) assert.Equal(t, len(collectionSchema.Fields)+1, len(data)) @@ -2211,7 +2312,7 @@ func TestVector(t *testing.T) { } row[field] = value body, _ = wrapRequestBody([]map[string]interface{}{row}) - err, _, _ = checkAndSetData(body, collectionSchema) + err, _, _ = checkAndSetData(body, collectionSchema, false) assert.Error(t, err) } diff --git a/internal/mocks/mock_proxy.go b/internal/mocks/mock_proxy.go index b0fd9a62fd..1f2146c813 100644 --- a/internal/mocks/mock_proxy.go +++ b/internal/mocks/mock_proxy.go @@ -7650,7 +7650,8 @@ func (_c *MockProxy_Upsert_Call) RunAndReturn(run func(context.Context, *milvusp func NewMockProxy(t interface { mock.TestingT Cleanup(func()) -}) *MockProxy { +}, +) *MockProxy { mock := &MockProxy{} mock.Mock.Test(t) diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index b4ac840e7a..5d883e806f 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -2845,10 +2845,11 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) log := log.Ctx(ctx).With( zap.String("role", typeutil.ProxyRole), - zap.String("db", request.DbName), - zap.String("collection", request.CollectionName), - zap.String("partition", request.PartitionName), - zap.Uint32("NumRows", request.NumRows), + zap.String("db", request.GetDbName()), + zap.String("collection", request.GetCollectionName()), + zap.String("partition", request.GetPartitionName()), + zap.Uint32("NumRows", request.GetNumRows()), + zap.Bool("partialUpdate", request.GetPartialUpdate()), ) log.Debug("Start processing upsert request in Proxy") @@ -2890,6 +2891,7 @@ func (node *Proxy) Upsert(ctx context.Context, request *milvuspb.UpsertRequest) idAllocator: node.rowIDAllocator, chMgr: node.chMgr, schemaTimestamp: request.SchemaTimestamp, + node: node, } log.Debug("Enqueue upsert request in Proxy", diff --git a/internal/proxy/task_upsert.go b/internal/proxy/task_upsert.go index 78b7f712b5..bde01bad29 100644 --- a/internal/proxy/task_upsert.go +++ b/internal/proxy/task_upsert.go @@ -17,6 +17,7 @@ package proxy import ( "context" + "fmt" "strconv" "github.com/cockroachdb/errors" @@ -28,11 +29,14 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/allocator" + "github.com/milvus-io/milvus/internal/parser/planparserv2" + "github.com/milvus-io/milvus/internal/types" 
"github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/pkg/v2/common" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/mq/msgstream" + "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil" "github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" @@ -66,6 +70,12 @@ type upsertTask struct { // delete task need use the oldIDs oldIDs *schemapb.IDs schemaTimestamp uint64 + + // write after read, generate write part by queryPreExecute + node types.ProxyComponent + + deletePKs *schemapb.IDs + insertFieldData []*schemapb.FieldData } // TraceCtx returns upsertTask context @@ -145,6 +155,361 @@ func (it *upsertTask) OnEnqueue() error { return nil } +func retrieveByPKs(ctx context.Context, t *upsertTask, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error) { + log := log.Ctx(ctx).With(zap.String("collectionName", t.req.GetCollectionName())) + var err error + queryReq := &milvuspb.QueryRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Retrieve, + Timestamp: t.BeginTs(), + }, + DbName: t.req.GetDbName(), + CollectionName: t.req.GetCollectionName(), + ConsistencyLevel: commonpb.ConsistencyLevel_Strong, + NotReturnAllMeta: false, + OutputFields: []string{"*"}, + UseDefaultConsistency: false, + GuaranteeTimestamp: t.BeginTs(), + } + pkField, err := typeutil.GetPrimaryFieldSchema(t.schema.CollectionSchema) + if err != nil { + return nil, err + } + + var partitionIDs []int64 + if t.partitionKeyMode { + // multi entities with same pk and diff partition keys may be hashed to multi physical partitions + // if deleteMsg.partitionID = common.InvalidPartition, + // all segments with this pk under the collection will have the delete record + partitionIDs = []int64{common.AllPartitionsID} + queryReq.PartitionNames = []string{} + } else { + // partition name could be defaultPartitionName or name specified by sdk + partName := t.upsertMsg.DeleteMsg.PartitionName + if err := validatePartitionTag(partName, true); err != nil { + log.Warn("Invalid partition name", zap.String("partitionName", partName), zap.Error(err)) + return nil, err + } + partID, err := globalMetaCache.GetPartitionID(ctx, t.req.GetDbName(), t.req.GetCollectionName(), partName) + if err != nil { + log.Warn("Failed to get partition id", zap.String("partitionName", partName), zap.Error(err)) + return nil, err + } + partitionIDs = []int64{partID} + queryReq.PartitionNames = []string{partName} + } + + plan := planparserv2.CreateRequeryPlan(pkField, ids) + qt := &queryTask{ + ctx: t.ctx, + Condition: NewTaskCondition(t.ctx), + RetrieveRequest: &internalpb.RetrieveRequest{ + Base: commonpbutil.NewMsgBase( + commonpbutil.WithMsgType(commonpb.MsgType_Retrieve), + commonpbutil.WithSourceID(paramtable.GetNodeID()), + ), + ReqID: paramtable.GetNodeID(), + PartitionIDs: partitionIDs, + ConsistencyLevel: commonpb.ConsistencyLevel_Strong, + }, + request: queryReq, + plan: plan, + mixCoord: t.node.(*Proxy).mixCoord, + lb: t.node.(*Proxy).lbPolicy, + } + + ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-retrieveByPKs") + defer func() { + sp.End() + }() + queryResult, err := t.node.(*Proxy).query(ctx, qt, sp) + if err := merr.CheckRPCCall(queryResult.GetStatus(), err); err != nil { + return nil, err + } + return queryResult, err +} + +func (it *upsertTask) queryPreExecute(ctx context.Context) error { + log := 
log.Ctx(ctx).With(zap.String("collectionName", it.req.CollectionName)) + + primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(it.schema.CollectionSchema) + if err != nil { + log.Warn("get primary field schema failed", zap.Error(err)) + return err + } + + primaryFieldData, err := typeutil.GetPrimaryFieldData(it.req.GetFieldsData(), primaryFieldSchema) + if err != nil { + log.Error("get primary field data failed", zap.Error(err)) + return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("must assign pk when upsert, primary field: %v", primaryFieldSchema.Name)) + } + + oldIDs, err := parsePrimaryFieldData2IDs(primaryFieldData) + if err != nil { + log.Warn("parse primary field data to IDs failed", zap.Error(err)) + return err + } + + oldIDSize := typeutil.GetSizeOfIDs(oldIDs) + if oldIDSize == 0 { + it.deletePKs = &schemapb.IDs{} + it.insertFieldData = it.req.GetFieldsData() + log.Info("old records not found, just do insert") + return nil + } + + tr := timerecord.NewTimeRecorder("Proxy-Upsert-retrieveByPKs") + // retrieve by primary key to get original field data + resp, err := retrieveByPKs(ctx, it, oldIDs, []string{"*"}) + if err != nil { + log.Info("retrieve by primary key failed", zap.Error(err)) + return err + } + + if len(resp.GetFieldsData()) == 0 { + return merr.WrapErrParameterInvalidMsg("retrieve by primary key failed, no data found") + } + + existFieldData := resp.GetFieldsData() + pkFieldData, err := typeutil.GetPrimaryFieldData(existFieldData, primaryFieldSchema) + if err != nil { + log.Error("get primary field data failed", zap.Error(err)) + return err + } + existIDs, err := parsePrimaryFieldData2IDs(pkFieldData) + if err != nil { + log.Info("parse primary field data to ids failed", zap.Error(err)) + return err + } + log.Info("retrieveByPKs cost", + zap.Int("resultNum", typeutil.GetSizeOfIDs(existIDs)), + zap.Int64("latency", tr.ElapseSpan().Milliseconds())) + + // check whether the primary key is exist in query result + idsChecker, err := typeutil.NewIDsChecker(existIDs) + if err != nil { + log.Info("create primary key checker failed", zap.Error(err)) + return err + } + + // Build mapping from existing primary keys to their positions in query result + // This ensures we can correctly locate data even if query results are not in the same order as request + existIDsLen := typeutil.GetSizeOfIDs(existIDs) + existPKToIndex := make(map[interface{}]int, existIDsLen) + for j := 0; j < existIDsLen; j++ { + pk := typeutil.GetPK(existIDs, int64(j)) + existPKToIndex[pk] = j + } + + // set field id for user passed field data + upsertFieldData := it.upsertMsg.InsertMsg.GetFieldsData() + if len(upsertFieldData) == 0 { + return merr.WrapErrParameterInvalidMsg("upsert field data is empty") + } + for _, fieldData := range upsertFieldData { + fieldName := fieldData.GetFieldName() + if fieldData.GetIsDynamic() { + fieldName = "$meta" + } + fieldID, ok := it.schema.MapFieldID(fieldName) + if !ok { + log.Info("field not found in schema", zap.Any("field", fieldData)) + return merr.WrapErrParameterInvalidMsg("field not found in schema") + } + fieldData.FieldId = fieldID + fieldData.FieldName = fieldName + } + + lackOfFieldErr := LackOfFieldsDataBySchema(it.schema.CollectionSchema, it.upsertMsg.InsertMsg.GetFieldsData(), false, true) + it.deletePKs = &schemapb.IDs{} + it.insertFieldData = make([]*schemapb.FieldData, len(existFieldData)) + for i := 0; i < oldIDSize; i++ { + exist, err := idsChecker.Contains(oldIDs, i) + if err != nil { + log.Info("check primary key exist in query result failed", 
zap.Error(err)) + return err + } + + if exist { + // treat upsert as update + // 1. if pk exist in query result, add it to deletePKs + typeutil.AppendIDs(it.deletePKs, oldIDs, i) + // 2. construct the field data for update using correct index mapping + oldPK := typeutil.GetPK(oldIDs, int64(i)) + existIndex, ok := existPKToIndex[oldPK] + if !ok { + return merr.WrapErrParameterInvalidMsg("primary key not found in exist data mapping") + } + typeutil.AppendFieldData(it.insertFieldData, existFieldData, int64(existIndex)) + err := typeutil.UpdateFieldData(it.insertFieldData, upsertFieldData, int64(i)) + if err != nil { + log.Info("update field data failed", zap.Error(err)) + return err + } + } else { + // treat upsert as insert + if lackOfFieldErr != nil { + log.Info("check fields data by schema failed", zap.Error(lackOfFieldErr)) + return lackOfFieldErr + } + // use field data from upsert request + typeutil.AppendFieldData(it.insertFieldData, upsertFieldData, int64(i)) + } + } + + for _, fieldData := range it.insertFieldData { + if fieldData.GetIsDynamic() { + continue + } + fieldSchema, err := it.schema.schemaHelper.GetFieldFromName(fieldData.GetFieldName()) + if err != nil { + log.Info("get field schema failed", zap.Error(err)) + return err + } + + // Note: Since protobuf cannot correctly identify null values, zero values + valid data are used to identify null values, + // therefore for field data obtained from query results, if the field is nullable, it needs to be set to empty values + if fieldSchema.GetNullable() { + if getValidNumber(fieldData.GetValidData()) != len(fieldData.GetValidData()) { + err := ResetNullFieldData(fieldData, fieldSchema) + if err != nil { + log.Info("reset null field data failed", zap.Error(err)) + return err + } + } + } + + // Note: For fields containing default values, default values need to be set according to valid data during insertion, + // but query results fields do not set valid data when returning default value fields, + // therefore valid data needs to be manually set to true + if fieldSchema.GetDefaultValue() != nil { + fieldData.ValidData = make([]bool, oldIDSize) + for i := range fieldData.ValidData { + fieldData.ValidData[i] = true + } + } + } + + return nil +} + +func ResetNullFieldData(field *schemapb.FieldData, fieldSchema *schemapb.FieldSchema) error { + if !fieldSchema.GetNullable() { + return nil + } + + switch field.Field.(type) { + case *schemapb.FieldData_Scalars: + switch sd := field.GetScalars().GetData().(type) { + case *schemapb.ScalarField_BoolData: + validRowNum := getValidNumber(field.GetValidData()) + if validRowNum == 0 { + sd.BoolData.Data = make([]bool, 0) + } else { + ret := make([]bool, validRowNum) + for i, valid := range field.GetValidData() { + if valid { + ret = append(ret, sd.BoolData.Data[i]) + } + } + sd.BoolData.Data = ret + } + + case *schemapb.ScalarField_IntData: + validRowNum := getValidNumber(field.GetValidData()) + if validRowNum == 0 { + sd.IntData.Data = make([]int32, 0) + } else { + ret := make([]int32, validRowNum) + for i, valid := range field.GetValidData() { + if valid { + ret = append(ret, sd.IntData.Data[i]) + } + } + sd.IntData.Data = ret + } + + case *schemapb.ScalarField_LongData: + validRowNum := getValidNumber(field.GetValidData()) + if validRowNum == 0 { + sd.LongData.Data = make([]int64, 0) + } else { + ret := make([]int64, validRowNum) + for i, valid := range field.GetValidData() { + if valid { + ret = append(ret, sd.LongData.Data[i]) + } + } + sd.LongData.Data = ret + } + + case 
*schemapb.ScalarField_FloatData: + validRowNum := getValidNumber(field.GetValidData()) + if validRowNum == 0 { + sd.FloatData.Data = make([]float32, 0) + } else { + ret := make([]float32, validRowNum) + for i, valid := range field.GetValidData() { + if valid { + ret = append(ret, sd.FloatData.Data[i]) + } + } + sd.FloatData.Data = ret + } + + case *schemapb.ScalarField_DoubleData: + validRowNum := getValidNumber(field.GetValidData()) + if validRowNum == 0 { + sd.DoubleData.Data = make([]float64, 0) + } else { + ret := make([]float64, validRowNum) + for i, valid := range field.GetValidData() { + if valid { + ret = append(ret, sd.DoubleData.Data[i]) + } + } + sd.DoubleData.Data = ret + } + + case *schemapb.ScalarField_StringData: + validRowNum := getValidNumber(field.GetValidData()) + if validRowNum == 0 { + sd.StringData.Data = make([]string, 0) + } else { + ret := make([]string, validRowNum) + for i, valid := range field.GetValidData() { + if valid { + ret = append(ret, sd.StringData.Data[i]) + } + } + sd.StringData.Data = ret + } + + case *schemapb.ScalarField_JsonData: + validRowNum := getValidNumber(field.GetValidData()) + if validRowNum == 0 { + sd.JsonData.Data = make([][]byte, 0) + } else { + ret := make([][]byte, validRowNum) + for i, valid := range field.GetValidData() { + if valid { + ret = append(ret, sd.JsonData.Data[i]) + } + } + sd.JsonData.Data = ret + } + + default: + return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined data type:%s", field.Type.String())) + } + + default: + return merr.WrapErrParameterInvalidMsg(fmt.Sprintf("undefined data type:%s", field.Type.String())) + } + + return nil +} + func (it *upsertTask) insertPreExecute(ctx context.Context) error { ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Upsert-insertPreExecute") defer sp.End() @@ -154,8 +519,20 @@ func (it *upsertTask) insertPreExecute(ctx context.Context) error { return err } + bm25Fields := typeutil.NewSet[string](GetFunctionOutputFields(it.schema.CollectionSchema)...) 
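The write-after-read behaviour implemented by queryPreExecute above can be pictured with the following simplified, map-based sketch. It is not the proxy code (which works on schemapb.FieldData columns): field names are placeholders and plain Go maps stand in for rows. Rows whose primary key already exists are overlaid onto the stored values and re-inserted while the old row is deleted; unknown primary keys become plain inserts and must carry a full row.

package main

import "fmt"

// mergePartial returns the rows to insert and the primary keys whose old
// versions must be deleted, given the stored rows keyed by PK and the
// partial rows supplied by the caller.
func mergePartial(existing map[int64]map[string]any, updates []map[string]any) (inserts []map[string]any, deletePKs []int64) {
	for _, upd := range updates {
		pk := upd["id"].(int64)
		if old, ok := existing[pk]; ok {
			merged := map[string]any{}
			for k, v := range old { // start from the stored row
				merged[k] = v
			}
			for k, v := range upd { // overlay only the provided fields
				merged[k] = v
			}
			inserts = append(inserts, merged)
			deletePKs = append(deletePKs, pk) // old version is deleted
		} else {
			inserts = append(inserts, upd) // new PK: behaves like a plain insert
		}
	}
	return inserts, deletePKs
}

func main() {
	existing := map[int64]map[string]any{
		1: {"id": int64(1), "name": "old", "count": int64(7)},
	}
	updates := []map[string]any{
		{"id": int64(1), "name": "new"},                      // touches one field of row 1
		{"id": int64(2), "name": "fresh", "count": int64(0)}, // brand-new row, all fields given
	}
	ins, del := mergePartial(existing, updates)
	fmt.Println(ins, del) // row 1 keeps count=7 with name=new; row 2 inserted; delete PK 1
}
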
// Calculate embedding fields if function.HasNonBM25Functions(it.schema.CollectionSchema.Functions, []int64{}) { + if it.req.PartialUpdate { + // remove the old bm25 fields + ret := make([]*schemapb.FieldData, 0) + for _, fieldData := range it.upsertMsg.InsertMsg.GetFieldsData() { + if bm25Fields.Contain(fieldData.GetFieldName()) { + continue + } + ret = append(ret, fieldData) + } + it.upsertMsg.InsertMsg.FieldsData = ret + } ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Proxy-Upsert-insertPreExecute-call-function-udf") defer sp.End() exec, err := function.NewFunctionExecutor(it.schema.CollectionSchema) @@ -271,17 +648,19 @@ func (it *upsertTask) deletePreExecute(ctx context.Context) error { log := log.Ctx(ctx).With( zap.String("collectionName", collName)) + if it.upsertMsg.DeleteMsg.PrimaryKeys == nil { + // if primary keys are not set by queryPreExecute, use oldIDs to delete all given records + it.upsertMsg.DeleteMsg.PrimaryKeys = it.oldIDs + } + if typeutil.GetSizeOfIDs(it.upsertMsg.DeleteMsg.PrimaryKeys) == 0 { + log.Info("deletePKs is empty, skip deleteExecute") + return nil + } + if err := validateCollectionName(collName); err != nil { log.Info("Invalid collectionName", zap.Error(err)) return err } - collID, err := globalMetaCache.GetCollectionID(ctx, it.req.GetDbName(), collName) - if err != nil { - log.Info("Failed to get collection id", zap.Error(err)) - return err - } - it.upsertMsg.DeleteMsg.CollectionID = collID - it.collectionID = collID if it.partitionKeyMode { // multi entities with same pk and diff partition keys may be hashed to multi physical partitions @@ -335,11 +714,14 @@ func (it *upsertTask) PreExecute(ctx context.Context) error { return merr.WrapErrCollectionReplicateMode("upsert") } + // check collection exists collID, err := globalMetaCache.GetCollectionID(context.Background(), it.req.GetDbName(), collectionName) if err != nil { log.Warn("fail to get collection id", zap.Error(err)) return err } + it.collectionID = collID + colInfo, err := globalMetaCache.GetCollectionInfo(ctx, it.req.GetDbName(), collectionName, collID) if err != nil { log.Warn("fail to get collection info", zap.Error(err)) @@ -398,6 +780,7 @@ func (it *upsertTask) PreExecute(ctx context.Context) error { commonpbutil.WithSourceID(paramtable.GetNodeID()), ), CollectionName: it.req.CollectionName, + CollectionID: it.collectionID, PartitionName: it.req.PartitionName, FieldsData: it.req.FieldsData, NumRows: uint64(it.req.NumRows), @@ -413,12 +796,30 @@ func (it *upsertTask) PreExecute(ctx context.Context) error { ), DbName: it.req.DbName, CollectionName: it.req.CollectionName, + CollectionID: it.collectionID, NumRows: int64(it.req.NumRows), PartitionName: it.req.PartitionName, - CollectionID: it.collectionID, }, }, } + + // check if num_rows is valid + if it.req.NumRows <= 0 { + return merr.WrapErrParameterInvalid("invalid num_rows", fmt.Sprint(it.req.NumRows), "num_rows should be greater than 0") + } + + if it.req.GetPartialUpdate() { + err = it.queryPreExecute(ctx) + if err != nil { + log.Warn("Fail to queryPreExecute", zap.Error(err)) + return err + } + // reconstruct upsert msg after queryPreExecute + it.upsertMsg.InsertMsg.FieldsData = it.insertFieldData + it.upsertMsg.DeleteMsg.PrimaryKeys = it.deletePKs + it.upsertMsg.DeleteMsg.NumRows = int64(typeutil.GetSizeOfIDs(it.deletePKs)) + } + err = it.insertPreExecute(ctx) if err != nil { log.Warn("Fail to insertPreExecute", zap.Error(err)) diff --git a/internal/proxy/task_upsert_test.go b/internal/proxy/task_upsert_test.go index 
81cf15aae9..386dec2f82 100644 --- a/internal/proxy/task_upsert_test.go +++ b/internal/proxy/task_upsert_test.go @@ -19,6 +19,7 @@ import ( "context" "testing" + "github.com/bytedance/mockey" "github.com/cockroachdb/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -28,14 +29,18 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/allocator" + grpcmixcoordclient "github.com/milvus-io/milvus/internal/distributed/mixcoord/client" "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/parser/planparserv2" "github.com/milvus-io/milvus/internal/util/function" "github.com/milvus-io/milvus/pkg/v2/mq/msgstream" + "github.com/milvus-io/milvus/pkg/v2/proto/planpb" "github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb" "github.com/milvus-io/milvus/pkg/v2/util/commonpbutil" "github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/testutils" + "github.com/milvus-io/milvus/pkg/v2/util/typeutil" ) func TestUpsertTask_CheckAligned(t *testing.T) { @@ -156,7 +161,7 @@ func TestUpsertTask_CheckAligned(t *testing.T) { case2.req.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows) case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData err = case2.upsertMsg.InsertMsg.CheckAligned() - assert.NoError(t, err) + assert.NoError(t, nil, err) // less int8 data case2.req.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows/2) @@ -520,6 +525,7 @@ func TestUpsertTaskForSchemaMismatch(t *testing.T) { ctx: ctx, req: &milvuspb.UpsertRequest{ CollectionName: "col-0", + NumRows: 10, }, schemaTimestamp: 99, } @@ -533,3 +539,546 @@ func TestUpsertTaskForSchemaMismatch(t *testing.T) { assert.ErrorIs(t, err, merr.ErrCollectionSchemaMismatch) }) } + +// Helper function to create test updateTask +func createTestUpdateTask() *upsertTask { + mcClient := &grpcmixcoordclient.Client{} + + upsertTask := &upsertTask{ + baseTask: baseTask{}, + Condition: NewTaskCondition(context.Background()), + req: &milvuspb.UpsertRequest{ + DbName: "test_db", + CollectionName: "test_collection", + PartitionName: "_default", + FieldsData: []*schemapb.FieldData{ + { + FieldName: "id", + FieldId: 100, + Type: schemapb.DataType_Int64, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}}, + }, + }, + }, + }, + { + FieldName: "name", + FieldId: 102, + Type: schemapb.DataType_VarChar, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{Data: []string{"test1", "test2", "test3"}}, + }, + }, + }, + }, + { + FieldName: "vector", + FieldId: 101, + Type: schemapb.DataType_FloatVector, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: 128, + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{Data: make([]float32, 384)}, // 3 * 128 + }, + }, + }, + }, + }, + NumRows: 3, + }, + ctx: context.Background(), + schema: createTestSchema(), + collectionID: 1001, + node: &Proxy{ + mixCoord: mcClient, + lbPolicy: NewLBPolicyImpl(nil), + }, + } + + return upsertTask +} + +// Helper function to create test schema +func createTestSchema() *schemaInfo { + schema := &schemapb.CollectionSchema{ + Name: "test_collection", + Fields: 
[]*schemapb.FieldSchema{ + { + FieldID: 100, + Name: "id", + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + }, + { + FieldID: 101, + Name: "vector", + DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "128"}, + }, + }, + { + FieldID: 102, + Name: "name", + DataType: schemapb.DataType_VarChar, + }, + }, + } + return newSchemaInfo(schema) +} + +func TestRetrieveByPKs_Success(t *testing.T) { + mockey.PatchConvey("TestRetrieveByPKs_Success", t, func() { + // Setup mocks + mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{ + FieldID: 100, + Name: "id", + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + }, nil).Build() + + mockey.Mock(validatePartitionTag).Return(nil).Build() + + mockey.Mock((*MetaCache).GetPartitionID).Return(int64(1002), nil).Build() + + mockey.Mock(planparserv2.CreateRequeryPlan).Return(&planpb.PlanNode{}).Build() + + mockey.Mock((*Proxy).query).Return(&milvuspb.QueryResults{ + Status: merr.Success(), + FieldsData: []*schemapb.FieldData{ + { + FieldName: "id", + FieldId: 100, + Type: schemapb.DataType_Int64, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{Data: []int64{1, 2}}, + }, + }, + }, + }, + }, + }, nil).Build() + + globalMetaCache = &MetaCache{} + mockey.Mock(globalMetaCache.GetPartitionID).Return(int64(1002), nil).Build() + + // Execute test + task := createTestUpdateTask() + task.partitionKeyMode = false + task.upsertMsg = &msgstream.UpsertMsg{ + InsertMsg: &msgstream.InsertMsg{ + InsertRequest: &msgpb.InsertRequest{ + PartitionName: "_default", + }, + }, + DeleteMsg: &msgstream.DeleteMsg{ + DeleteRequest: &msgpb.DeleteRequest{ + PartitionName: "_default", + }, + }, + } + + ids := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{Data: []int64{1, 2}}, + }, + } + + result, err := retrieveByPKs(context.Background(), task, ids, []string{"*"}) + + // Verify results + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, commonpb.ErrorCode_Success, result.Status.ErrorCode) + assert.Len(t, result.FieldsData, 1) + }) +} + +func TestRetrieveByPKs_GetPrimaryFieldSchemaError(t *testing.T) { + mockey.PatchConvey("TestRetrieveByPKs_GetPrimaryFieldSchemaError", t, func() { + expectedErr := merr.WrapErrParameterInvalidMsg("primary field not found") + mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(nil, expectedErr).Build() + + task := createTestUpdateTask() + ids := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{Data: []int64{1, 2}}, + }, + } + + result, err := retrieveByPKs(context.Background(), task, ids, []string{"*"}) + + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "primary field not found") + }) +} + +func TestRetrieveByPKs_PartitionKeyMode(t *testing.T) { + mockey.PatchConvey("TestRetrieveByPKs_PartitionKeyMode", t, func() { + mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{ + FieldID: 100, + Name: "id", + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + }, nil).Build() + + mockey.Mock(planparserv2.CreateRequeryPlan).Return(&planpb.PlanNode{}).Build() + + mockey.Mock((*Proxy).query).Return(&milvuspb.QueryResults{ + Status: merr.Success(), + FieldsData: []*schemapb.FieldData{}, + }, nil).Build() + + task := createTestUpdateTask() + task.partitionKeyMode = true + + ids := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: 
&schemapb.LongArray{Data: []int64{1, 2}}, + }, + } + + result, err := retrieveByPKs(context.Background(), task, ids, []string{"*"}) + + assert.NoError(t, err) + assert.NotNil(t, result) + }) +} + +func TestUpdateTask_queryPreExecute_Success(t *testing.T) { + mockey.PatchConvey("TestUpdateTask_queryPreExecute_Success", t, func() { + // Setup mocks + mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{ + FieldID: 100, + Name: "id", + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + }, nil).Build() + + mockey.Mock(typeutil.GetPrimaryFieldData).Return(&schemapb.FieldData{ + FieldName: "id", + FieldId: 100, + Type: schemapb.DataType_Int64, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}}, + }, + }, + }, + }, nil).Build() + + mockey.Mock(parsePrimaryFieldData2IDs).Return(&schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{Data: []int64{1, 2, 3}}, + }, + }, nil).Build() + + mockey.Mock(typeutil.GetSizeOfIDs).Return(3).Build() + + mockey.Mock(retrieveByPKs).Return(&milvuspb.QueryResults{ + Status: merr.Success(), + FieldsData: []*schemapb.FieldData{ + { + FieldName: "id", + FieldId: 100, + Type: schemapb.DataType_Int64, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{Data: []int64{1, 2}}, + }, + }, + }, + }, + { + FieldName: "name", + FieldId: 102, + Type: schemapb.DataType_VarChar, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{Data: []string{"old1", "old2"}}, + }, + }, + }, + }, + { + FieldName: "vector", + FieldId: 101, + Type: schemapb.DataType_FloatVector, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: 128, + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{Data: make([]float32, 256)}, // 2 * 128 + }, + }, + }, + }, + }, + }, nil).Build() + + mockey.Mock(typeutil.NewIDsChecker).Return(&typeutil.IDsChecker{}, nil).Build() + + // Execute test + task := createTestUpdateTask() + task.schema = createTestSchema() + task.upsertMsg = &msgstream.UpsertMsg{ + InsertMsg: &msgstream.InsertMsg{ + InsertRequest: &msgpb.InsertRequest{ + FieldsData: []*schemapb.FieldData{ + { + FieldName: "id", + FieldId: 100, + Type: schemapb.DataType_Int64, + }, + { + FieldName: "name", + FieldId: 102, + Type: schemapb.DataType_VarChar, + }, + { + FieldName: "vector", + FieldId: 101, + Type: schemapb.DataType_FloatVector, + }, + }, + }, + }, + } + + err := task.queryPreExecute(context.Background()) + + // Verify results + assert.NoError(t, err) + assert.NotNil(t, task.deletePKs) + assert.NotNil(t, task.insertFieldData) + }) +} + +func TestUpdateTask_queryPreExecute_GetPrimaryFieldSchemaError(t *testing.T) { + mockey.PatchConvey("TestUpdateTask_queryPreExecute_GetPrimaryFieldSchemaError", t, func() { + expectedErr := merr.WrapErrParameterInvalidMsg("primary field not found") + mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(nil, expectedErr).Build() + + task := createTestUpdateTask() + task.schema = createTestSchema() + + err := task.queryPreExecute(context.Background()) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "primary field not found") + }) +} + +func TestUpdateTask_queryPreExecute_GetPrimaryFieldDataError(t *testing.T) { + 
mockey.PatchConvey("TestUpdateTask_queryPreExecute_GetPrimaryFieldDataError", t, func() { + mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{ + FieldID: 100, + Name: "id", + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + }, nil).Build() + + expectedErr := merr.WrapErrParameterInvalidMsg("primary field data not found") + mockey.Mock(typeutil.GetPrimaryFieldData).Return(nil, expectedErr).Build() + + task := createTestUpdateTask() + task.schema = createTestSchema() + + err := task.queryPreExecute(context.Background()) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "must assign pk when upsert") + }) +} + +func TestUpdateTask_queryPreExecute_EmptyOldIDs(t *testing.T) { + mockey.PatchConvey("TestUpdateTask_queryPreExecute_EmptyOldIDs", t, func() { + mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{ + FieldID: 100, + Name: "id", + IsPrimaryKey: true, + DataType: schemapb.DataType_Int64, + }, nil).Build() + + mockey.Mock(typeutil.GetPrimaryFieldData).Return(&schemapb.FieldData{ + FieldName: "id", + FieldId: 100, + Type: schemapb.DataType_Int64, + }, nil).Build() + + mockey.Mock(parsePrimaryFieldData2IDs).Return(&schemapb.IDs{}, nil).Build() + + mockey.Mock(typeutil.GetSizeOfIDs).Return(0).Build() + + task := createTestUpdateTask() + task.schema = createTestSchema() + + err := task.queryPreExecute(context.Background()) + + assert.NoError(t, err) + assert.NotNil(t, task.deletePKs) + assert.Equal(t, task.req.GetFieldsData(), task.insertFieldData) + }) +} + +func TestUpdateTask_PreExecute_Success(t *testing.T) { + mockey.PatchConvey("TestUpdateTask_PreExecute_Success", t, func() { + // Setup mocks + globalMetaCache = &MetaCache{} + + mockey.Mock(GetReplicateID).Return("", nil).Build() + + mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build() + + mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{ + updateTimestamp: 12345, + }, nil).Build() + + mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build() + + mockey.Mock(isPartitionKeyMode).Return(false, nil).Build() + + mockey.Mock((*MetaCache).GetPartitionInfo).Return(&partitionInfo{ + name: "_default", + }, nil).Build() + + mockey.Mock((*upsertTask).queryPreExecute).Return(nil).Build() + + mockey.Mock((*upsertTask).insertPreExecute).Return(nil).Build() + + mockey.Mock((*upsertTask).deletePreExecute).Return(nil).Build() + + // Execute test + task := createTestUpdateTask() + task.req.PartialUpdate = true + + err := task.PreExecute(context.Background()) + + // Verify results + assert.NoError(t, err) + assert.NotNil(t, task.result) + assert.Equal(t, int64(1001), task.collectionID) + assert.NotNil(t, task.schema) + assert.NotNil(t, task.upsertMsg) + }) +} + +func TestUpdateTask_PreExecute_ReplicateIDError(t *testing.T) { + mockey.PatchConvey("TestUpdateTask_PreExecute_ReplicateIDError", t, func() { + globalMetaCache = &MetaCache{} + + mockey.Mock(GetReplicateID).Return("replica1", nil).Build() + + task := createTestUpdateTask() + + err := task.PreExecute(context.Background()) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "can't operate on the collection under standby mode") + }) +} + +func TestUpdateTask_PreExecute_GetCollectionIDError(t *testing.T) { + mockey.PatchConvey("TestUpdateTask_PreExecute_GetCollectionIDError", t, func() { + globalMetaCache = &MetaCache{} + + mockey.Mock(GetReplicateID).Return("", nil).Build() + + expectedErr := merr.WrapErrCollectionNotFound("test_collection") + 
mockey.Mock((*MetaCache).GetCollectionID).Return(int64(0), expectedErr).Build() + + task := createTestUpdateTask() + + err := task.PreExecute(context.Background()) + + assert.Error(t, err) + }) +} + +func TestUpdateTask_PreExecute_PartitionKeyModeError(t *testing.T) { + mockey.PatchConvey("TestUpdateTask_PreExecute_PartitionKeyModeError", t, func() { + globalMetaCache = &MetaCache{} + + mockey.Mock(GetReplicateID).Return("", nil).Build() + mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build() + mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{ + updateTimestamp: 12345, + }, nil).Build() + mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build() + + mockey.Mock(isPartitionKeyMode).Return(true, nil).Build() + + task := createTestUpdateTask() + task.req.PartitionName = "custom_partition" // This should cause error in partition key mode + + err := task.PreExecute(context.Background()) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "not support manually specifying the partition names if partition key mode is used") + }) +} + +func TestUpdateTask_PreExecute_InvalidNumRows(t *testing.T) { + mockey.PatchConvey("TestUpdateTask_PreExecute_InvalidNumRows", t, func() { + globalMetaCache = &MetaCache{} + + mockey.Mock(GetReplicateID).Return("", nil).Build() + mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build() + mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{ + updateTimestamp: 12345, + }, nil).Build() + mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build() + mockey.Mock(isPartitionKeyMode).Return(false, nil).Build() + mockey.Mock((*MetaCache).GetPartitionInfo).Return(&partitionInfo{ + name: "_default", + }, nil).Build() + + task := createTestUpdateTask() + task.req.NumRows = 0 // Invalid num_rows + + err := task.PreExecute(context.Background()) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid num_rows") + }) +} + +func TestUpdateTask_PreExecute_QueryPreExecuteError(t *testing.T) { + mockey.PatchConvey("TestUpdateTask_PreExecute_QueryPreExecuteError", t, func() { + globalMetaCache = &MetaCache{} + + mockey.Mock(GetReplicateID).Return("", nil).Build() + mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build() + mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{ + updateTimestamp: 12345, + }, nil).Build() + mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build() + mockey.Mock(isPartitionKeyMode).Return(false, nil).Build() + mockey.Mock((*MetaCache).GetPartitionInfo).Return(&partitionInfo{ + name: "_default", + }, nil).Build() + + expectedErr := merr.WrapErrParameterInvalidMsg("query pre-execute failed") + mockey.Mock((*upsertTask).queryPreExecute).Return(expectedErr).Build() + + task := createTestUpdateTask() + task.req.PartialUpdate = true + + err := task.PreExecute(context.Background()) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "query pre-execute failed") + }) +} diff --git a/internal/proxy/util.go b/internal/proxy/util.go index c5f48c576b..48949a818c 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -1934,6 +1934,38 @@ func checkPrimaryFieldData(allFields []*schemapb.FieldSchema, schema *schemapb.C return ids, nil } +// check whether insertMsg has all fields in schema +func LackOfFieldsDataBySchema(schema *schemapb.CollectionSchema, fieldsData []*schemapb.FieldData, skipPkFieldCheck bool, skipDynamicFieldCheck bool) error { + log := 
log.With(zap.String("collection", schema.GetName())) + + // find bm25 generated fields + bm25Fields := typeutil.NewSet[string](GetFunctionOutputFields(schema)...) + dataNameMap := make(map[string]*schemapb.FieldData) + for _, data := range fieldsData { + dataNameMap[data.GetFieldName()] = data + } + + for _, fieldSchema := range schema.Fields { + if bm25Fields.Contain(fieldSchema.GetName()) { + continue + } + + if _, ok := dataNameMap[fieldSchema.GetName()]; !ok { + if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && skipPkFieldCheck) || + IsBM25FunctionOutputField(fieldSchema, schema) || + (skipDynamicFieldCheck && fieldSchema.GetIsDynamic()) { + // autoGenField + continue + } + + log.Info("no corresponding fieldData pass in", zap.String("fieldSchema", fieldSchema.GetName())) + return merr.WrapErrParameterInvalidMsg("fieldSchema(%s) has no corresponding fieldData pass in", fieldSchema.GetName()) + } + } + + return nil +} + // for some varchar with analzyer // we need check char format before insert it to message queue // now only support utf-8 @@ -2490,6 +2522,14 @@ func IsBM25FunctionOutputField(field *schemapb.FieldSchema, collSchema *schemapb return false } +func GetFunctionOutputFields(collSchema *schemapb.CollectionSchema) []string { + fields := make([]string, 0) + for _, fSchema := range collSchema.Functions { + fields = append(fields, fSchema.OutputFieldNames...) + } + return fields +} + func getCollectionTTL(pairs []*commonpb.KeyValuePair) uint64 { properties := make(map[string]string) for _, pair := range pairs { diff --git a/pkg/go.mod b/pkg/go.mod index 85313a156b..12beb0274c 100644 --- a/pkg/go.mod +++ b/pkg/go.mod @@ -216,7 +216,6 @@ replace ( github.com/go-kit/kit => github.com/go-kit/kit v0.1.0 github.com/golang-jwt/jwt => github.com/golang-jwt/jwt/v4 v4.5.2 // indirect github.com/ianlancetaylor/cgosymbolizer => github.com/milvus-io/cgosymbolizer v0.0.0-20250318084424-114f4050c3a6 - github.com/milvus-io/milvus-proto/go-api/v2 v2.6.0-rc.1.0.20250716031043-88051c3893ce => /home/zhicheng/SecWorkspace/milvus-proto/go-api github.com/streamnative/pulsarctl => github.com/xiaofan-luan/pulsarctl v0.5.1 github.com/tecbot/gorocksdb => github.com/milvus-io/gorocksdb v0.0.0-20220624081344-8c5f4212846b // indirect ) diff --git a/pkg/util/typeutil/ids_checker.go b/pkg/util/typeutil/ids_checker.go new file mode 100644 index 0000000000..c09434c839 --- /dev/null +++ b/pkg/util/typeutil/ids_checker.go @@ -0,0 +1,212 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
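// A minimal standalone sketch (not part of any file in this patch) of the field-completeness
// rule enforced by the LackOfFieldsDataBySchema helper added to internal/proxy/util.go above:
// function (e.g. BM25) output fields, an auto-generated primary key, and (for partial updates)
// the dynamic field may be absent from the incoming payload, while every other schema field must
// carry data. The helper name missingFieldData below is hypothetical and is only a simplified
// approximation of that rule (it omits the proxy config checks).
package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
)

// missingFieldData returns the schema fields that have no corresponding entry in fieldsData,
// skipping fields the server is allowed to generate or, for partial updates, to leave untouched.
func missingFieldData(schema *schemapb.CollectionSchema, fieldsData []*schemapb.FieldData, skipDynamic bool) []string {
	present := make(map[string]struct{}, len(fieldsData))
	for _, fd := range fieldsData {
		present[fd.GetFieldName()] = struct{}{}
	}
	// Function output fields are produced server-side and never supplied by clients.
	functionOutputs := make(map[string]struct{})
	for _, fn := range schema.GetFunctions() {
		for _, name := range fn.GetOutputFieldNames() {
			functionOutputs[name] = struct{}{}
		}
	}
	var missing []string
	for _, f := range schema.GetFields() {
		if _, ok := present[f.GetName()]; ok {
			continue
		}
		if _, ok := functionOutputs[f.GetName()]; ok {
			continue
		}
		if f.GetIsPrimaryKey() && f.GetAutoID() {
			continue // auto-generated primary key may be omitted
		}
		if skipDynamic && f.GetIsDynamic() {
			continue // dynamic JSON field may be omitted on partial update
		}
		missing = append(missing, f.GetName())
	}
	return missing
}

func main() {
	schema := &schemapb.CollectionSchema{
		Name: "demo",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "id", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
			{FieldID: 101, Name: "vector", DataType: schemapb.DataType_FloatVector},
			{FieldID: 102, Name: "name", DataType: schemapb.DataType_VarChar},
		},
	}
	// A payload carrying only the primary key and one scalar field.
	data := []*schemapb.FieldData{{FieldName: "id"}, {FieldName: "name"}}
	fmt.Println(missingFieldData(schema, data, true)) // [vector]
}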
+ +package typeutil + +import ( + "fmt" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +// IDsChecker provides efficient lookup functionality for schema.IDs +// It supports checking if an ID at a specific position in one IDs exists in another IDs +type IDsChecker struct { + intIDSet map[int64]struct{} + strIDSet map[string]struct{} + idType schemapb.DataType +} + +// NewIDsChecker creates a new IDsChecker from the given IDs +// The checker will build an internal set for fast lookup +func NewIDsChecker(ids *schemapb.IDs) (*IDsChecker, error) { + if ids == nil || ids.GetIdField() == nil { + return &IDsChecker{}, nil + } + + checker := &IDsChecker{} + + switch ids.GetIdField().(type) { + case *schemapb.IDs_IntId: + checker.idType = schemapb.DataType_Int64 + checker.intIDSet = make(map[int64]struct{}) + data := ids.GetIntId().GetData() + for _, id := range data { + checker.intIDSet[id] = struct{}{} + } + case *schemapb.IDs_StrId: + checker.idType = schemapb.DataType_VarChar + checker.strIDSet = make(map[string]struct{}) + data := ids.GetStrId().GetData() + for _, id := range data { + checker.strIDSet[id] = struct{}{} + } + default: + return nil, fmt.Errorf("unsupported ID type in IDs") + } + + return checker, nil +} + +// Contains checks if the ID at the specified cursor position in idsA exists in this checker +// Returns true if the ID exists, false otherwise +// Returns error if cursor is out of bounds or type mismatch +func (c *IDsChecker) Contains(idsA *schemapb.IDs, cursor int) (bool, error) { + if idsA == nil || idsA.GetIdField() == nil { + return false, fmt.Errorf("idsA is nil or empty") + } + + // Check if cursor is within bounds + size := GetSizeOfIDs(idsA) + if cursor < 0 || cursor >= size { + return false, fmt.Errorf("cursor %d is out of bounds [0, %d)", cursor, size) + } + + // If checker is empty, return false for any query + if c.IsEmpty() { + return false, nil + } + + switch idsA.GetIdField().(type) { + case *schemapb.IDs_IntId: + if c.idType != schemapb.DataType_Int64 { + return false, fmt.Errorf("type mismatch: checker expects %v, got Int64", c.idType) + } + if c.intIDSet == nil { + return false, nil + } + id := idsA.GetIntId().GetData()[cursor] + _, exists := c.intIDSet[id] + return exists, nil + + case *schemapb.IDs_StrId: + if c.idType != schemapb.DataType_VarChar { + return false, fmt.Errorf("type mismatch: checker expects %v, got VarChar", c.idType) + } + if c.strIDSet == nil { + return false, nil + } + id := idsA.GetStrId().GetData()[cursor] + _, exists := c.strIDSet[id] + return exists, nil + + default: + return false, fmt.Errorf("unsupported ID type in idsA") + } +} + +// ContainsAny checks if any ID in idsA exists in this checker +// Returns the indices of IDs that exist in the checker +func (c *IDsChecker) ContainsAny(idsA *schemapb.IDs) ([]int, error) { + if idsA == nil || idsA.GetIdField() == nil { + return nil, fmt.Errorf("idsA is nil or empty") + } + + var result []int + + // If checker is empty, return empty result for any query + if c.IsEmpty() { + return result, nil + } + + switch idsA.GetIdField().(type) { + case *schemapb.IDs_IntId: + if c.idType != schemapb.DataType_Int64 { + return nil, fmt.Errorf("type mismatch: checker expects %v, got Int64", c.idType) + } + if c.intIDSet == nil { + return result, nil + } + data := idsA.GetIntId().GetData() + for i, id := range data { + if _, exists := c.intIDSet[id]; exists { + result = append(result, i) + } + } + + case *schemapb.IDs_StrId: + if c.idType != schemapb.DataType_VarChar { + return nil, 
fmt.Errorf("type mismatch: checker expects %v, got VarChar", c.idType) + } + if c.strIDSet == nil { + return result, nil + } + data := idsA.GetStrId().GetData() + for i, id := range data { + if _, exists := c.strIDSet[id]; exists { + result = append(result, i) + } + } + + default: + return nil, fmt.Errorf("unsupported ID type in idsA") + } + + return result, nil +} + +// Size returns the number of unique IDs in this checker +func (c *IDsChecker) Size() int { + switch c.idType { + case schemapb.DataType_Int64: + if c.intIDSet == nil { + return 0 + } + return len(c.intIDSet) + case schemapb.DataType_VarChar: + if c.strIDSet == nil { + return 0 + } + return len(c.strIDSet) + default: + return 0 + } +} + +// IsEmpty returns true if the checker contains no IDs +func (c *IDsChecker) IsEmpty() bool { + return c.Size() == 0 +} + +// GetIDType returns the data type of IDs in this checker +func (c *IDsChecker) GetIDType() schemapb.DataType { + return c.idType +} + +// ContainsIDsAtCursors is a batch operation that checks multiple cursor positions at once +// Returns a slice of booleans indicating whether each cursor position exists in the checker +func (c *IDsChecker) ContainsIDsAtCursors(idsA *schemapb.IDs, cursors []int) ([]bool, error) { + if idsA == nil || idsA.GetIdField() == nil { + return nil, fmt.Errorf("idsA is nil or empty") + } + + size := GetSizeOfIDs(idsA) + results := make([]bool, len(cursors)) + + for i, cursor := range cursors { + if cursor < 0 || cursor >= size { + return nil, fmt.Errorf("cursor %d is out of bounds [0, %d)", cursor, size) + } + + exists, err := c.Contains(idsA, cursor) + if err != nil { + return nil, err + } + results[i] = exists + } + + return results, nil +} diff --git a/pkg/util/typeutil/ids_checker_test.go b/pkg/util/typeutil/ids_checker_test.go new file mode 100644 index 0000000000..cb03acf9dd --- /dev/null +++ b/pkg/util/typeutil/ids_checker_test.go @@ -0,0 +1,372 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
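// A minimal standalone usage sketch (not part of any file in this patch) for the IDsChecker
// defined above: build it once from the primary keys that the pre-upsert query found in the
// collection, then probe the incoming upsert PKs by cursor to decide which rows already exist.
// It assumes the import path github.com/milvus-io/milvus/pkg/v2/util/typeutil used elsewhere
// in this patch.
package main

import (
	"fmt"

	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

func main() {
	// Primary keys already present in the collection (e.g. returned by the pre-upsert retrieve).
	existing := &schemapb.IDs{
		IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1, 2, 3}}},
	}
	checker, err := typeutil.NewIDsChecker(existing)
	if err != nil {
		panic(err)
	}

	// Primary keys carried by the upsert request: 2 exists, 42 does not.
	incoming := &schemapb.IDs{
		IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{2, 42}}},
	}
	for cursor := 0; cursor < 2; cursor++ {
		exists, err := checker.Contains(incoming, cursor)
		if err != nil {
			panic(err)
		}
		// A hit means the row can be merged with the queried record; a miss is written as-is.
		fmt.Printf("cursor %d exists=%v\n", cursor, exists)
	}

	// Batch form: indices of incoming IDs that already exist (here, [0]).
	hits, _ := checker.ContainsAny(incoming)
	fmt.Println(hits)
}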
+ +package typeutil + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" +) + +func TestNewIDsChecker(t *testing.T) { + t.Run("nil IDs", func(t *testing.T) { + checker, err := NewIDsChecker(nil) + assert.NoError(t, err) + assert.True(t, checker.IsEmpty()) + }) + + t.Run("empty IDs", func(t *testing.T) { + ids := &schemapb.IDs{} + checker, err := NewIDsChecker(ids) + assert.NoError(t, err) + assert.True(t, checker.IsEmpty()) + }) + + t.Run("int64 IDs", func(t *testing.T) { + ids := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4, 5}, + }, + }, + } + checker, err := NewIDsChecker(ids) + assert.NoError(t, err) + assert.False(t, checker.IsEmpty()) + assert.Equal(t, 5, checker.Size()) + assert.Equal(t, schemapb.DataType_Int64, checker.GetIDType()) + }) + + t.Run("string IDs", func(t *testing.T) { + ids := &schemapb.IDs{ + IdField: &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: []string{"a", "b", "c", "d"}, + }, + }, + } + checker, err := NewIDsChecker(ids) + assert.NoError(t, err) + assert.False(t, checker.IsEmpty()) + assert.Equal(t, 4, checker.Size()) + assert.Equal(t, schemapb.DataType_VarChar, checker.GetIDType()) + }) + + t.Run("duplicate IDs", func(t *testing.T) { + ids := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1, 2, 2, 3, 3, 3}, + }, + }, + } + checker, err := NewIDsChecker(ids) + assert.NoError(t, err) + assert.Equal(t, 3, checker.Size()) // Only unique IDs are counted + }) +} + +func TestIDsChecker_Contains(t *testing.T) { + // Create checker with int64 IDs + checkerIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{10, 20, 30, 40, 50}, + }, + }, + } + checker, err := NewIDsChecker(checkerIDs) + require.NoError(t, err) + + t.Run("valid int64 lookups", func(t *testing.T) { + queryIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{10, 15, 20, 25, 30}, + }, + }, + } + + // Test each position + exists, err := checker.Contains(queryIDs, 0) // 10 - should exist + assert.NoError(t, err) + assert.True(t, exists) + + exists, err = checker.Contains(queryIDs, 1) // 15 - should not exist + assert.NoError(t, err) + assert.False(t, exists) + + exists, err = checker.Contains(queryIDs, 2) // 20 - should exist + assert.NoError(t, err) + assert.True(t, exists) + + exists, err = checker.Contains(queryIDs, 3) // 25 - should not exist + assert.NoError(t, err) + assert.False(t, exists) + + exists, err = checker.Contains(queryIDs, 4) // 30 - should exist + assert.NoError(t, err) + assert.True(t, exists) + }) + + t.Run("out of bounds", func(t *testing.T) { + queryIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{10, 20}, + }, + }, + } + + _, err := checker.Contains(queryIDs, -1) + assert.Error(t, err) + + _, err = checker.Contains(queryIDs, 2) + assert.Error(t, err) + }) + + t.Run("type mismatch", func(t *testing.T) { + queryIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: []string{"a", "b"}, + }, + }, + } + + _, err := checker.Contains(queryIDs, 0) + assert.Error(t, err) + assert.Contains(t, err.Error(), "type mismatch") + }) + + t.Run("nil query IDs", func(t *testing.T) { + _, err := checker.Contains(nil, 0) + assert.Error(t, err) + }) +} + +func TestIDsChecker_ContainsString(t 
*testing.T) { + // Create checker with string IDs + checkerIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: []string{"apple", "banana", "cherry", "date"}, + }, + }, + } + checker, err := NewIDsChecker(checkerIDs) + require.NoError(t, err) + + t.Run("valid string lookups", func(t *testing.T) { + queryIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: []string{"apple", "grape", "banana", "kiwi"}, + }, + }, + } + + exists, err := checker.Contains(queryIDs, 0) // "apple" - should exist + assert.NoError(t, err) + assert.True(t, exists) + + exists, err = checker.Contains(queryIDs, 1) // "grape" - should not exist + assert.NoError(t, err) + assert.False(t, exists) + + exists, err = checker.Contains(queryIDs, 2) // "banana" - should exist + assert.NoError(t, err) + assert.True(t, exists) + + exists, err = checker.Contains(queryIDs, 3) // "kiwi" - should not exist + assert.NoError(t, err) + assert.False(t, exists) + }) +} + +func TestIDsChecker_ContainsAny(t *testing.T) { + checkerIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{10, 20, 30, 40, 50}, + }, + }, + } + checker, err := NewIDsChecker(checkerIDs) + require.NoError(t, err) + + t.Run("some matches", func(t *testing.T) { + queryIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{5, 10, 15, 20, 25, 30}, + }, + }, + } + + indices, err := checker.ContainsAny(queryIDs) + assert.NoError(t, err) + assert.Equal(t, []int{1, 3, 5}, indices) // positions of 10, 20, 30 + }) + + t.Run("no matches", func(t *testing.T) { + queryIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4}, + }, + }, + } + + indices, err := checker.ContainsAny(queryIDs) + assert.NoError(t, err) + assert.Empty(t, indices) + }) + + t.Run("all matches", func(t *testing.T) { + queryIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{10, 20, 30}, + }, + }, + } + + indices, err := checker.ContainsAny(queryIDs) + assert.NoError(t, err) + assert.Equal(t, []int{0, 1, 2}, indices) + }) +} + +func TestIDsChecker_ContainsIDsAtCursors(t *testing.T) { + checkerIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{10, 20, 30, 40, 50}, + }, + }, + } + checker, err := NewIDsChecker(checkerIDs) + require.NoError(t, err) + + queryIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{10, 15, 20, 25, 30, 35}, + }, + }, + } + + t.Run("batch check", func(t *testing.T) { + cursors := []int{0, 1, 2, 4} + results, err := checker.ContainsIDsAtCursors(queryIDs, cursors) + assert.NoError(t, err) + assert.Equal(t, []bool{true, false, true, true}, results) + }) + + t.Run("out of bounds in batch", func(t *testing.T) { + cursors := []int{0, 1, 10} // 10 is out of bounds + _, err := checker.ContainsIDsAtCursors(queryIDs, cursors) + assert.Error(t, err) + }) +} + +func TestIDsChecker_EmptyChecker(t *testing.T) { + checker, err := NewIDsChecker(nil) + require.NoError(t, err) + + queryIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{1, 2, 3}, + }, + }, + } + + // Empty checker should return false for any query + exists, err := checker.Contains(queryIDs, 0) + assert.NoError(t, err) + assert.False(t, exists) + + indices, err := checker.ContainsAny(queryIDs) + assert.NoError(t, err) + 
assert.Empty(t, indices) +} + +// Benchmark tests +func BenchmarkIDsChecker_Contains(b *testing.B) { + // Create a large checker + data := make([]int64, 10000) + for i := range data { + data[i] = int64(i * 2) // Even numbers + } + checkerIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{Data: data}, + }, + } + checker, _ := NewIDsChecker(checkerIDs) + + // Create query IDs + queryData := make([]int64, 1000) + for i := range queryData { + queryData[i] = int64(i) // Mix of even and odd numbers + } + queryIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{Data: queryData}, + }, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + checker.Contains(queryIDs, i%1000) + } +} + +func BenchmarkIDsChecker_ContainsAny(b *testing.B) { + // Create a large checker + data := make([]int64, 10000) + for i := range data { + data[i] = int64(i * 2) // Even numbers + } + checkerIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{Data: data}, + }, + } + checker, _ := NewIDsChecker(checkerIDs) + + // Create query IDs + queryData := make([]int64, 1000) + for i := range queryData { + queryData[i] = int64(i) // Mix of even and odd numbers + } + queryIDs := &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{Data: queryData}, + }, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + checker.ContainsAny(queryIDs) + } +} diff --git a/pkg/util/typeutil/schema.go b/pkg/util/typeutil/schema.go index a5c6895194..49db69b33e 100644 --- a/pkg/util/typeutil/schema.go +++ b/pkg/util/typeutil/schema.go @@ -752,31 +752,39 @@ func PrepareResultFieldData(sample []*schemapb.FieldData, topK int64) []*schemap // AppendFieldData appends fields data of specified index from src to dst func AppendFieldData(dst, src []*schemapb.FieldData, idx int64) (appendSize int64) { + dstMap := make(map[int64]*schemapb.FieldData) + for _, fieldData := range dst { + if fieldData != nil { + dstMap[fieldData.FieldId] = fieldData + } + } for i, fieldData := range src { - if dst[i] == nil { - dst[i] = &schemapb.FieldData{ + dstFieldData, ok := dstMap[fieldData.FieldId] + if !ok { + dstFieldData = &schemapb.FieldData{ Type: fieldData.Type, FieldName: fieldData.FieldName, FieldId: fieldData.FieldId, IsDynamic: fieldData.IsDynamic, } + dst[i] = dstFieldData } // assign null data if len(fieldData.GetValidData()) != 0 { - if dst[i].ValidData == nil { - dst[i].ValidData = make([]bool, 0) + if dstFieldData.ValidData == nil { + dstFieldData.ValidData = make([]bool, 0) } valid := fieldData.ValidData[idx] - dst[i].ValidData = append(dst[i].ValidData, valid) + dstFieldData.ValidData = append(dstFieldData.ValidData, valid) } switch fieldType := fieldData.Field.(type) { case *schemapb.FieldData_Scalars: - if dst[i] == nil || dst[i].GetScalars() == nil { - dst[i].Field = &schemapb.FieldData_Scalars{ + if dstFieldData.GetScalars() == nil { + dstFieldData.Field = &schemapb.FieldData_Scalars{ Scalars: &schemapb.ScalarField{}, } } - dstScalar := dst[i].GetScalars() + dstScalar := dstFieldData.GetScalars() switch srcScalar := fieldType.Scalars.Data.(type) { case *schemapb.ScalarField_BoolData: if dstScalar.GetBoolData() == nil { @@ -880,19 +888,14 @@ func AppendFieldData(dst, src []*schemapb.FieldData, idx int64) (appendSize int6 } case *schemapb.FieldData_Vectors: dim := fieldType.Vectors.Dim - if dst[i] == nil || dst[i].GetVectors() == nil { - dst[i] = &schemapb.FieldData{ - Type: fieldData.Type, - FieldName: fieldData.FieldName, - FieldId: 
fieldData.FieldId, - Field: &schemapb.FieldData_Vectors{ - Vectors: &schemapb.VectorField{ - Dim: dim, - }, + if dstFieldData.GetVectors() == nil { + dstFieldData.Field = &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: dim, }, } } - dstVector := dst[i].GetVectors() + dstVector := dstFieldData.GetVectors() switch srcVector := fieldType.Vectors.Data.(type) { case *schemapb.VectorField_BinaryVector: if dstVector.GetBinaryVector() == nil { @@ -1052,6 +1055,192 @@ func DeleteFieldData(dst []*schemapb.FieldData) { } } +func UpdateFieldData(base, update []*schemapb.FieldData, idx int64) error { + // Create a map for quick lookup of update fields by field ID + updateFieldMap := make(map[string]*schemapb.FieldData) + for _, fieldData := range update { + updateFieldMap[fieldData.FieldName] = fieldData + } + // Iterate through base fields and update if corresponding field exists in update + for _, baseFieldData := range base { + updateFieldData, exists := updateFieldMap[baseFieldData.FieldName] + if !exists { + // No update for this field, keep original value + continue + } + + // Update ValidData if present + if len(updateFieldData.GetValidData()) != 0 { + if len(baseFieldData.GetValidData()) != 0 { + baseFieldData.ValidData[idx] = updateFieldData.ValidData[idx] + } + + // update field data to null,only modify valid data + if !updateFieldData.ValidData[idx] { + continue + } + } + + // Update field data based on type + switch baseFieldType := baseFieldData.Field.(type) { + case *schemapb.FieldData_Scalars: + updateFieldType := updateFieldData.Field.(*schemapb.FieldData_Scalars) + baseScalar := baseFieldType.Scalars + updateScalar := updateFieldType.Scalars + + switch baseScalar.Data.(type) { + case *schemapb.ScalarField_BoolData: + updateData := updateScalar.GetBoolData() + if updateData != nil && int(idx) < len(updateData.Data) { + baseScalar.GetBoolData().Data[idx] = updateData.Data[idx] + } + case *schemapb.ScalarField_IntData: + updateData := updateScalar.GetIntData() + if updateData != nil && int(idx) < len(updateData.Data) { + baseScalar.GetIntData().Data[idx] = updateData.Data[idx] + } + case *schemapb.ScalarField_LongData: + updateData := updateScalar.GetLongData() + if updateData != nil && int(idx) < len(updateData.Data) { + baseScalar.GetLongData().Data[idx] = updateData.Data[idx] + } + case *schemapb.ScalarField_FloatData: + updateData := updateScalar.GetFloatData() + if updateData != nil && int(idx) < len(updateData.Data) { + baseScalar.GetFloatData().Data[idx] = updateData.Data[idx] + } + case *schemapb.ScalarField_DoubleData: + updateData := updateScalar.GetDoubleData() + if updateData != nil && int(idx) < len(updateData.Data) { + baseScalar.GetDoubleData().Data[idx] = updateData.Data[idx] + } + case *schemapb.ScalarField_StringData: + updateData := updateScalar.GetStringData() + if updateData != nil && int(idx) < len(updateData.Data) { + baseScalar.GetStringData().Data[idx] = updateData.Data[idx] + } + case *schemapb.ScalarField_ArrayData: + updateData := updateScalar.GetArrayData() + if updateData != nil && int(idx) < len(updateData.Data) { + baseScalar.GetArrayData().Data[idx] = updateData.Data[idx] + } + case *schemapb.ScalarField_JsonData: + updateData := updateScalar.GetJsonData() + baseData := baseScalar.GetJsonData() + if updateData != nil && int(idx) < len(updateData.Data) { + if baseFieldData.GetIsDynamic() { + // dynamic field is a json with only 1 level nested struct, + // so we need to unmarshal and iterate updateData's key value, and update the 
baseData's key value + var baseMap map[string]interface{} + var updateMap map[string]interface{} + // unmarshal base and update + if err := json.Unmarshal(baseData.Data[idx], &baseMap); err != nil { + return fmt.Errorf("failed to unmarshal base json: %v", err) + } + if err := json.Unmarshal(updateData.Data[idx], &updateMap); err != nil { + return fmt.Errorf("failed to unmarshal update json: %v", err) + } + // merge + for k, v := range updateMap { + baseMap[k] = v + } + // marshal back + newJSON, err := json.Marshal(baseMap) + if err != nil { + return fmt.Errorf("failed to marshal merged json: %v", err) + } + baseScalar.GetJsonData().Data[idx] = newJSON + } else { + baseScalar.GetJsonData().Data[idx] = updateData.Data[idx] + } + } + default: + log.Error("Not supported scalar field type", zap.String("field type", baseFieldData.Type.String())) + return fmt.Errorf("unsupported scalar field type: %s", baseFieldData.Type.String()) + } + + case *schemapb.FieldData_Vectors: + updateFieldType := updateFieldData.Field.(*schemapb.FieldData_Vectors) + baseVector := baseFieldType.Vectors + updateVector := updateFieldType.Vectors + dim := baseVector.Dim + + switch baseVector.Data.(type) { + case *schemapb.VectorField_BinaryVector: + updateData := updateVector.GetBinaryVector() + if updateData != nil { + baseData := baseVector.GetBinaryVector() + startIdx := idx * (dim / 8) + endIdx := (idx + 1) * (dim / 8) + if int(endIdx) <= len(updateData) && int(endIdx) <= len(baseData) { + copy(baseData[startIdx:endIdx], updateData[startIdx:endIdx]) + } + } + case *schemapb.VectorField_FloatVector: + updateData := updateVector.GetFloatVector() + if updateData != nil { + baseData := baseVector.GetFloatVector() + startIdx := idx * dim + endIdx := (idx + 1) * dim + if int(endIdx) <= len(updateData.Data) && int(endIdx) <= len(baseData.Data) { + copy(baseData.Data[startIdx:endIdx], updateData.Data[startIdx:endIdx]) + } + } + case *schemapb.VectorField_Float16Vector: + updateData := updateVector.GetFloat16Vector() + if updateData != nil { + baseData := baseVector.GetFloat16Vector() + startIdx := idx * (dim * 2) + endIdx := (idx + 1) * (dim * 2) + if int(endIdx) <= len(updateData) && int(endIdx) <= len(baseData) { + copy(baseData[startIdx:endIdx], updateData[startIdx:endIdx]) + } + } + case *schemapb.VectorField_Bfloat16Vector: + updateData := updateVector.GetBfloat16Vector() + if updateData != nil { + baseData := baseVector.GetBfloat16Vector() + startIdx := idx * (dim * 2) + endIdx := (idx + 1) * (dim * 2) + if int(endIdx) <= len(updateData) && int(endIdx) <= len(baseData) { + copy(baseData[startIdx:endIdx], updateData[startIdx:endIdx]) + } + } + case *schemapb.VectorField_SparseFloatVector: + updateData := updateVector.GetSparseFloatVector() + if updateData != nil && int(idx) < len(updateData.Contents) { + baseData := baseVector.GetSparseFloatVector() + if int(idx) < len(baseData.Contents) { + baseData.Contents[idx] = updateData.Contents[idx] + // Update dimension if necessary + if updateData.Dim > baseData.Dim { + baseData.Dim = updateData.Dim + } + } + } + case *schemapb.VectorField_Int8Vector: + updateData := updateVector.GetInt8Vector() + if updateData != nil { + baseData := baseVector.GetInt8Vector() + startIdx := idx * dim + endIdx := (idx + 1) * dim + if int(endIdx) <= len(updateData) && int(endIdx) <= len(baseData) { + copy(baseData[startIdx:endIdx], updateData[startIdx:endIdx]) + } + } + default: + log.Error("Not supported vector field type", zap.String("field type", baseFieldData.Type.String())) + return 
fmt.Errorf("unsupported vector field type: %s", baseFieldData.Type.String()) + } + default: + log.Error("Not supported field type", zap.String("field type", baseFieldData.Type.String())) + return fmt.Errorf("unsupported field type: %s", baseFieldData.Type.String()) + } + } + + return nil +} + // MergeFieldData appends fields data to dst func MergeFieldData(dst []*schemapb.FieldData, src []*schemapb.FieldData) error { fieldID2Data := make(map[int64]*schemapb.FieldData) diff --git a/pkg/util/typeutil/schema_test.go b/pkg/util/typeutil/schema_test.go index 30297bd47b..c339b1136a 100644 --- a/pkg/util/typeutil/schema_test.go +++ b/pkg/util/typeutil/schema_test.go @@ -18,6 +18,7 @@ package typeutil import ( "encoding/binary" + "encoding/json" "fmt" "math" "math/rand" @@ -2980,3 +2981,443 @@ func TestGetDataIterator(t *testing.T) { }) } } + +func TestUpdateFieldData(t *testing.T) { + const ( + Dim = 8 + BoolFieldName = "BoolField" + Int64FieldName = "Int64Field" + FloatFieldName = "FloatField" + StringFieldName = "StringField" + FloatVectorFieldName = "FloatVectorField" + BoolFieldID = common.StartOfUserFieldID + 1 + Int64FieldID = common.StartOfUserFieldID + 2 + FloatFieldID = common.StartOfUserFieldID + 3 + StringFieldID = common.StartOfUserFieldID + 4 + FloatVectorFieldID = common.StartOfUserFieldID + 5 + ) + + t.Run("update scalar fields", func(t *testing.T) { + // Create base data + baseData := []*schemapb.FieldData{ + { + Type: schemapb.DataType_Bool, + FieldName: BoolFieldName, + FieldId: BoolFieldID, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: []bool{true, false, true, false}, + }, + }, + }, + }, + }, + { + Type: schemapb.DataType_Int64, + FieldName: Int64FieldName, + FieldId: Int64FieldID, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4}, + }, + }, + }, + }, + }, + { + Type: schemapb.DataType_VarChar, + FieldName: StringFieldName, + FieldId: StringFieldID, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"a", "b", "c", "d"}, + }, + }, + }, + }, + }, + } + + // Create update data (only update some fields) + updateData := []*schemapb.FieldData{ + { + Type: schemapb.DataType_Bool, + FieldName: BoolFieldName, + FieldId: BoolFieldID, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_BoolData{ + BoolData: &schemapb.BoolArray{ + Data: []bool{false, true, false, true}, // Updated values + }, + }, + }, + }, + }, + { + Type: schemapb.DataType_VarChar, + FieldName: StringFieldName, + FieldId: StringFieldID, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_StringData{ + StringData: &schemapb.StringArray{ + Data: []string{"x", "y", "z", "w"}, // Updated values + }, + }, + }, + }, + }, + // Note: Int64Field is not in update, so it should remain unchanged + } + + // Update index 1 + err := UpdateFieldData(baseData, updateData, 1) + require.NoError(t, err) + + // Check results + // Bool field should be updated at index 1 + assert.Equal(t, true, baseData[0].GetScalars().GetBoolData().Data[1]) + // Int64 field should remain unchanged at index 1 + assert.Equal(t, int64(2), baseData[1].GetScalars().GetLongData().Data[1]) + // String field should be updated 
at index 1 + assert.Equal(t, "y", baseData[2].GetScalars().GetStringData().Data[1]) + + // Other indices should remain unchanged + assert.Equal(t, true, baseData[0].GetScalars().GetBoolData().Data[0]) + assert.Equal(t, int64(1), baseData[1].GetScalars().GetLongData().Data[0]) + assert.Equal(t, "a", baseData[2].GetScalars().GetStringData().Data[0]) + }) + + t.Run("update vector fields", func(t *testing.T) { + // Create base vector data + baseData := []*schemapb.FieldData{ + { + Type: schemapb.DataType_FloatVector, + FieldName: FloatVectorFieldName, + FieldId: FloatVectorFieldID, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: Dim, + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: []float32{ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, // vector 0 + 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, // vector 1 + }, + }, + }, + }, + }, + }, + } + + // Create update data + updateData := []*schemapb.FieldData{ + { + Type: schemapb.DataType_FloatVector, + FieldName: FloatVectorFieldName, + FieldId: FloatVectorFieldID, + Field: &schemapb.FieldData_Vectors{ + Vectors: &schemapb.VectorField{ + Dim: Dim, + Data: &schemapb.VectorField_FloatVector{ + FloatVector: &schemapb.FloatArray{ + Data: []float32{ + 100.0, 200.0, 300.0, 400.0, 500.0, 600.0, 700.0, 800.0, // vector 0 + 900.0, 1000.0, 1100.0, 1200.0, 1300.0, 1400.0, 1500.0, 1600.0, // vector 1 + }, + }, + }, + }, + }, + }, + } + + // Update index 1 (second vector) + err := UpdateFieldData(baseData, updateData, 1) + require.NoError(t, err) + + // Check results + vectorData := baseData[0].GetVectors().GetFloatVector().Data + + // First vector should remain unchanged + for i := 0; i < Dim; i++ { + assert.Equal(t, float32(i+1), vectorData[i]) + } + + // Second vector should be updated + for i := 0; i < Dim; i++ { + assert.Equal(t, float32(900+i*100), vectorData[Dim+i]) + } + }) + + t.Run("no update fields", func(t *testing.T) { + baseData := []*schemapb.FieldData{ + { + Type: schemapb.DataType_Int64, + FieldName: Int64FieldName, + FieldId: Int64FieldID, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4}, + }, + }, + }, + }, + }, + } + + // Empty update data + updateData := []*schemapb.FieldData{} + + // Update should succeed but change nothing + err := UpdateFieldData(baseData, updateData, 1) + require.NoError(t, err) + + // Data should remain unchanged + assert.Equal(t, int64(2), baseData[0].GetScalars().GetLongData().Data[1]) + }) + + t.Run("update with ValidData", func(t *testing.T) { + baseData := []*schemapb.FieldData{ + { + Type: schemapb.DataType_Int64, + FieldName: Int64FieldName, + FieldId: Int64FieldID, + ValidData: []bool{true, true, false, true}, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{1, 2, 3, 4}, + }, + }, + }, + }, + }, + } + + updateData := []*schemapb.FieldData{ + { + Type: schemapb.DataType_Int64, + FieldName: Int64FieldName, + FieldId: Int64FieldID, + ValidData: []bool{false, false, true, false}, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_LongData{ + LongData: &schemapb.LongArray{ + Data: []int64{10, 20, 30, 40}, + }, + }, + }, + }, + }, + } + + // Update index 1 + err := UpdateFieldData(baseData, updateData, 1) + require.NoError(t, err) + err = UpdateFieldData(baseData, 
updateData, 2) + require.NoError(t, err) + + // Check that ValidData was updated + assert.Equal(t, false, baseData[0].ValidData[1]) + // Check that data was updated + assert.Equal(t, int64(2), baseData[0].GetScalars().GetLongData().Data[1]) + assert.Equal(t, int64(30), baseData[0].GetScalars().GetLongData().Data[2]) + }) + + t.Run("update dynamic json field", func(t *testing.T) { + // Create base data with dynamic JSON field + baseData := []*schemapb.FieldData{ + { + Type: schemapb.DataType_JSON, + FieldName: "json_field", + FieldId: 1, + IsDynamic: true, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_JsonData{ + JsonData: &schemapb.JSONArray{ + Data: [][]byte{ + []byte(`{"key1": "value1", "key2": 123}`), + []byte(`{"key3": true, "key4": 456.789}`), + }, + }, + }, + }, + }, + }, + } + + // Create update data + updateData := []*schemapb.FieldData{ + { + Type: schemapb.DataType_JSON, + FieldName: "json_field", + FieldId: 1, + IsDynamic: true, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_JsonData{ + JsonData: &schemapb.JSONArray{ + Data: [][]byte{ + []byte(`{"key2": 999, "key5": "new_value"}`), + []byte(`{"key4": 111.222, "key6": false}`), + }, + }, + }, + }, + }, + }, + } + + // Test updating first row + err := UpdateFieldData(baseData, updateData, 0) + require.NoError(t, err) + + // Verify first row was correctly merged + firstRow := baseData[0].GetScalars().GetJsonData().Data[0] + var result map[string]interface{} + err = json.Unmarshal(firstRow, &result) + require.NoError(t, err) + + // Check merged values + assert.Equal(t, "value1", result["key1"]) + assert.Equal(t, float64(999), result["key2"]) // Updated value + assert.Equal(t, "new_value", result["key5"]) // New value + + // Test updating second row + err = UpdateFieldData(baseData, updateData, 1) + require.NoError(t, err) + + // Verify second row was correctly merged + secondRow := baseData[0].GetScalars().GetJsonData().Data[1] + err = json.Unmarshal(secondRow, &result) + require.NoError(t, err) + + // Check merged values + assert.Equal(t, true, result["key3"]) + assert.Equal(t, float64(111.222), result["key4"]) // Updated value + assert.Equal(t, false, result["key6"]) // New value + }) + + t.Run("update non-dynamic json field", func(t *testing.T) { + // Create base data with non-dynamic JSON field + baseData := []*schemapb.FieldData{ + { + Type: schemapb.DataType_JSON, + FieldName: "json_field", + FieldId: 1, + IsDynamic: false, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_JsonData{ + JsonData: &schemapb.JSONArray{ + Data: [][]byte{ + []byte(`{"key1": "value1"}`), + }, + }, + }, + }, + }, + }, + } + + // Create update data + updateData := []*schemapb.FieldData{ + { + Type: schemapb.DataType_JSON, + FieldName: "json_field", + FieldId: 1, + IsDynamic: false, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_JsonData{ + JsonData: &schemapb.JSONArray{ + Data: [][]byte{ + []byte(`{"key2": "value2"}`), + }, + }, + }, + }, + }, + }, + } + + // Test updating + err := UpdateFieldData(baseData, updateData, 0) + require.NoError(t, err) + + // For non-dynamic fields, the update should completely replace the old value + result := baseData[0].GetScalars().GetJsonData().Data[0] + assert.Equal(t, []byte(`{"key2": "value2"}`), result) + }) + + t.Run("invalid json data", func(t *testing.T) { + // Create base data with invalid JSON + 
baseData := []*schemapb.FieldData{ + { + Type: schemapb.DataType_JSON, + FieldName: "json_field", + FieldId: 1, + IsDynamic: true, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_JsonData{ + JsonData: &schemapb.JSONArray{ + Data: [][]byte{ + []byte(`invalid json`), + }, + }, + }, + }, + }, + }, + } + + updateData := []*schemapb.FieldData{ + { + Type: schemapb.DataType_JSON, + FieldName: "json_field", + FieldId: 1, + IsDynamic: true, + Field: &schemapb.FieldData_Scalars{ + Scalars: &schemapb.ScalarField{ + Data: &schemapb.ScalarField_JsonData{ + JsonData: &schemapb.JSONArray{ + Data: [][]byte{ + []byte(`{"key": "value"}`), + }, + }, + }, + }, + }, + }, + } + + // Test updating with invalid base JSON + err := UpdateFieldData(baseData, updateData, 0) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to unmarshal base json") + + // Create base data with valid JSON but invalid update + baseData[0].GetScalars().GetJsonData().Data[0] = []byte(`{"key": "value"}`) + updateData[0].GetScalars().GetJsonData().Data[0] = []byte(`invalid json`) + + // Test updating with invalid update JSON + err = UpdateFieldData(baseData, updateData, 0) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to unmarshal update json") + }) +} diff --git a/tests/go_client/common/response_checker.go b/tests/go_client/common/response_checker.go index 749603b701..a22e4d86a2 100644 --- a/tests/go_client/common/response_checker.go +++ b/tests/go_client/common/response_checker.go @@ -289,6 +289,22 @@ func CheckSearchIteratorResult(ctx context.Context, t *testing.T, itr client.Sea } } +// check expected columns should be contains in actual columns +func CheckPartialResult(t *testing.T, expColumns []column.Column, actualColumns []column.Column) { + for _, expColumn := range expColumns { + exist := false + for _, actualColumn := range actualColumns { + if expColumn.Name() == actualColumn.Name() && expColumn.Type() != entity.FieldTypeJSON { + exist = true + EqualColumn(t, expColumn, actualColumn) + } + } + if !exist { + log.Error("CheckQueryResult actualColumns no column", zap.String("name", expColumn.Name())) + } + } +} + // GenColumnDataOption -- create column data -- type checkIndexOpt struct { state index.IndexState diff --git a/tests/go_client/go.sum b/tests/go_client/go.sum index 4609a4ed92..431c5ab1b8 100644 --- a/tests/go_client/go.sum +++ b/tests/go_client/go.sum @@ -318,8 +318,6 @@ github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfr github.com/mediocregopher/radix/v3 v3.4.2/go.mod h1:8FL3F6UQRXHXIBSPUs5h0RybMF8i4n7wVopoX3x7Bv8= github.com/microcosm-cc/bluemonday v1.0.2/go.mod h1:iVP4YcDBq+n/5fb23BhYFvIMq/leAFZyRl6bYmGDlGc= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= -github.com/milvus-io/milvus-proto/go-api/v2 v2.6.1-0.20250807040333-531631e7fce6 h1:qTBOTsZ3OwEXkrHRqPn562ddkDqeToIY6CstLIaVQYs= -github.com/milvus-io/milvus-proto/go-api/v2 v2.6.1-0.20250807040333-531631e7fce6/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs= github.com/milvus-io/milvus-proto/go-api/v2 v2.6.1-0.20250807065533-ebdc11f5df17 h1:zyrKuc0rwT5xWIFkZr/bFWXXYbYvSBMT3iFITnaR8IE= github.com/milvus-io/milvus-proto/go-api/v2 v2.6.1-0.20250807065533-ebdc11f5df17/go.mod h1:/6UT4zZl6awVeXLeE7UGDWZvXj3IWkRsh3mqsn0DiAs= github.com/milvus-io/milvus/pkg/v2 v2.0.0-20250319085209-5a6b4e56d59e h1:VCr43pG4efacDbM4au70fh8/5hNTftoWzm1iEumvDWM= diff --git a/tests/go_client/testcases/helper/data_helper.go 
b/tests/go_client/testcases/helper/data_helper.go index a2f6721de7..b5ca79c05b 100644 --- a/tests/go_client/testcases/helper/data_helper.go +++ b/tests/go_client/testcases/helper/data_helper.go @@ -384,6 +384,10 @@ func GetAllFunctionsOutputFields(schema *entity.Schema) []string { return outputFields } +func GenColumnDataWithOption(fieldType entity.FieldType, option GenDataOption) column.Column { + return GenColumnData(option.nb, fieldType, option) +} + // GenColumnData GenColumnDataOption except dynamic column func GenColumnData(nb int, fieldType entity.FieldType, option GenDataOption) column.Column { dim := option.dim @@ -655,6 +659,10 @@ func GenColumnDataWithFp32VecConversion(nb int, fieldType entity.FieldType, opti } } +func GenDynamicColumnDataWithOption(option GenDataOption) []column.Column { + return GenDynamicColumnData(option.start, option.nb) +} + func GenDynamicColumnData(start int, nb int) []column.Column { type ListStruct struct { List []int64 `json:"list" milvus:"name:list"` diff --git a/tests/go_client/testcases/update_test.go b/tests/go_client/testcases/update_test.go new file mode 100644 index 0000000000..a34a42deea --- /dev/null +++ b/tests/go_client/testcases/update_test.go @@ -0,0 +1,206 @@ +package testcases + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/milvus-io/milvus/client/v2/column" + "github.com/milvus-io/milvus/client/v2/entity" + client "github.com/milvus-io/milvus/client/v2/milvusclient" + "github.com/milvus-io/milvus/pkg/v2/log" + "github.com/milvus-io/milvus/tests/go_client/common" + hp "github.com/milvus-io/milvus/tests/go_client/testcases/helper" +) + +func TestUpdatePartialFields(t *testing.T) { + /* + 1. prepare create -> insert -> index -> load -> query + 2. 
partial update existing entities -> data updated -> query and verify + */ + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + // connect + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> insert [0, 3000) -> flush -> index -> load + // create -> insert -> flush -> index -> load + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.AllFields), hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption()) + prepare.FlushData(ctx, t, mc, schema.CollectionName) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + + genPkWithSingleScalarField := func(option *hp.GenDataOption) ([]column.Column, []column.Column) { + log.Info("genPkWithSingleScalarField") + columns := make([]column.Column, 0, 2) + columns = append(columns, hp.GenColumnDataWithOption(entity.FieldTypeInt64, *option)) + columns = append(columns, hp.GenColumnDataWithOption(entity.FieldTypeFloat, *option)) + return columns, nil + } + + genPkWithSinglVectorField := func(option *hp.GenDataOption) ([]column.Column, []column.Column) { + log.Info("genPkWithSinglVectorField") + columns := make([]column.Column, 0, 2) + columns = append(columns, hp.GenColumnDataWithOption(entity.FieldTypeInt64, *option)) + columns = append(columns, hp.GenColumnDataWithOption(entity.FieldTypeFloatVector, *option)) + return columns, nil + } + + updateNb := 200 + for _, genColumnsFunc := range []func(*hp.GenDataOption) ([]column.Column, []column.Column){genPkWithSingleScalarField, genPkWithSinglVectorField} { + // perform partial update operation for existing entities [0, 200) -> query and verify + columns, dynamicColumns := genColumnsFunc(hp.TNewDataOption().TWithNb(updateNb).TWithStart(0)) + updateRes, err := mc.Upsert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName).WithColumns(columns...).WithColumns(dynamicColumns...).WithPartialUpdate(true)) + common.CheckErr(t, err, true) + require.EqualValues(t, updateNb, updateRes.UpsertCount) + + expr := fmt.Sprintf("%s < %d", common.DefaultInt64FieldName, updateNb) + resSet, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(expr).WithOutputFields("*").WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + common.CheckPartialResult(t, append(columns, hp.MergeColumnsToDynamic(updateNb, dynamicColumns, common.DefaultDynamicFieldName)), resSet.Fields) + } +} + +func TestPartialUpdateDynamicField(t *testing.T) { + // enable dynamic field and perform partial update operations on dynamic columns + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // create -> insert [0, 3000) -> flush -> index -> load + prepare, schema := hp.CollPrepare.CreateCollection(ctx, t, mc, hp.NewCreateCollectionParams(hp.Int64Vec), hp.TNewFieldsOption(), hp.TNewSchemaOption().TWithEnableDynamicField(true)) + prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema), hp.TNewDataOption()) + prepare.CreateIndex(ctx, t, mc, hp.TNewIndexParams(schema)) + prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName)) + time.Sleep(time.Second * 4) + // verify that dynamic field exists + testNb := 10 + resSet, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(fmt.Sprintf("%s < %d", common.DefaultDynamicNumberField, testNb)). 
+ WithOutputFields(common.DefaultDynamicFieldName).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + require.Equal(t, testNb, resSet.GetColumn(common.DefaultDynamicFieldName).Len()) + + // 1. query and gets empty + targetPk := int64(20000) + resSet, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(fmt.Sprintf("%s == %d", common.DefaultInt64FieldName, targetPk)). + WithOutputFields(common.DefaultDynamicFieldName).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + require.Equal(t, 0, resSet.GetColumn(common.DefaultDynamicFieldName).Len()) + + // 2. perform partial update operation for existing pk with dynamic column [a=1] + dynamicColumnA := column.NewColumnInt32(common.DefaultDynamicNumberField, []int32{1}) + vecColumn := hp.GenColumnData(1, entity.FieldTypeFloatVector, *hp.TNewDataOption()) + pkColumnA := column.NewColumnInt64(common.DefaultInt64FieldName, []int64{targetPk}) + _, err = mc.Upsert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName).WithColumns(pkColumnA, dynamicColumnA, vecColumn).WithPartialUpdate(true)) + common.CheckErr(t, err, true) + time.Sleep(time.Second * 4) + resSet, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(fmt.Sprintf("%s == %d", common.DefaultInt64FieldName, targetPk)). + WithOutputFields(common.DefaultDynamicFieldName).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + require.Equal(t, 1, resSet.GetColumn(common.DefaultDynamicFieldName).Len()) + common.EqualColumn(t, hp.MergeColumnsToDynamic(1, []column.Column{dynamicColumnA}, common.DefaultDynamicFieldName), resSet.GetColumn(common.DefaultDynamicFieldName)) + + // 3. perform partial update operation for existing pk with dynamic column [b=true] + dynamicColumnB := column.NewColumnBool(common.DefaultDynamicBoolField, []bool{true}) + _, err = mc.Upsert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName).WithColumns(pkColumnA, dynamicColumnB).WithPartialUpdate(true)) + common.CheckErr(t, err, true) + time.Sleep(time.Second * 4) + resSet, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(fmt.Sprintf("%s == %d", common.DefaultInt64FieldName, targetPk)). + WithOutputFields(common.DefaultDynamicFieldName).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + require.Equal(t, 1, resSet.GetColumn(common.DefaultDynamicFieldName).Len()) + common.EqualColumn(t, hp.MergeColumnsToDynamic(1, []column.Column{dynamicColumnA, dynamicColumnB}, common.DefaultDynamicFieldName), resSet.GetColumn(common.DefaultDynamicFieldName)) + + // 4. perform partial update operation for existing pk with dynamic column [a=2, b=false] + dynamicColumnA = column.NewColumnInt32(common.DefaultDynamicNumberField, []int32{2}) + _, err = mc.Upsert(ctx, client.NewColumnBasedInsertOption(schema.CollectionName).WithColumns(pkColumnA, dynamicColumnA).WithPartialUpdate(true)) + common.CheckErr(t, err, true) + time.Sleep(time.Second * 4) + resSet, err = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(fmt.Sprintf("%s == %d", common.DefaultInt64FieldName, targetPk)). 
+ WithOutputFields(common.DefaultDynamicFieldName).WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + require.Equal(t, 1, resSet.GetColumn(common.DefaultDynamicFieldName).Len()) + common.EqualColumn(t, hp.MergeColumnsToDynamic(1, []column.Column{dynamicColumnA, dynamicColumnB}, common.DefaultDynamicFieldName), resSet.GetColumn(common.DefaultDynamicFieldName)) +} + +func TestUpdateNullableFieldBehavior(t *testing.T) { + /* + Test nullable field behavior for Update operation: + 1. Insert data with nullable field having a value + 2. Update the same entity without providing the nullable field + 3. Verify that the nullable field retains its original value (not set to null) + */ + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // Create collection with nullable field using custom schema + collName := common.GenRandomString("update_nullable", 6) + + // Create fields including nullable field + pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true) + vecField := entity.NewField().WithName(common.DefaultFloatVecFieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim) + nullableField := entity.NewField().WithName("nullable_varchar").WithDataType(entity.FieldTypeVarChar).WithMaxLength(100).WithNullable(true) + + fields := []*entity.Field{pkField, vecField, nullableField} + schema := hp.GenSchema(hp.TNewSchemaOption().TWithName(collName).TWithDescription("test nullable field behavior for update").TWithFields(fields)) + + // Create collection using schema + err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema)) + common.CheckErr(t, err, true) + + // Cleanup + t.Cleanup(func() { + ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), time.Second*10) + defer cancel() + err := mc.DropCollection(ctx, client.NewDropCollectionOption(collName)) + common.CheckErr(t, err, true) + }) + + // Insert initial data with nullable field having a value + pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, []int64{1, 2, 3}) + vecColumn := hp.GenColumnData(3, entity.FieldTypeFloatVector, *hp.TNewDataOption()) + nullableColumn := column.NewColumnVarChar("nullable_varchar", []string{"original_1", "original_2", "original_3"}) + + _, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(collName).WithColumns(pkColumn, vecColumn, nullableColumn)) + common.CheckErr(t, err, true) + + // Use prepare pattern for remaining operations + prepare := &hp.CollectionPrepare{} + + // Flush data + prepare.FlushData(ctx, t, mc, collName) + + // Create index for vector field + indexParams := hp.TNewIndexParams(schema) + prepare.CreateIndex(ctx, t, mc, indexParams) + + // Load collection + loadParams := hp.NewLoadParams(collName) + prepare.Load(ctx, t, mc, loadParams) + + // Wait for loading to complete + time.Sleep(time.Second * 5) + + // Update entities without providing nullable field (should retain original values) + updatePkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, []int64{1, 2}) + updateVecColumn := hp.GenColumnData(2, entity.FieldTypeFloatVector, *hp.TNewDataOption().TWithStart(100)) + + updateRes, err := mc.Upsert(ctx, client.NewColumnBasedInsertOption(collName).WithColumns(updatePkColumn, updateVecColumn).WithPartialUpdate(true)) + common.CheckErr(t, err, true) + require.EqualValues(t, 2, updateRes.UpsertCount) + + // Wait for consistency + time.Sleep(time.Second * 3) + + // Query to 
verify nullable field retains original values + resSet, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("%s in [1, 2]", common.DefaultInt64FieldName)).WithOutputFields("*").WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + + // Verify results + require.Equal(t, 2, resSet.GetColumn("nullable_varchar").Len()) + nullableResults := resSet.GetColumn("nullable_varchar").(*column.ColumnVarChar).Data() + require.Equal(t, "original_1", nullableResults[0]) + require.Equal(t, "original_2", nullableResults[1]) +} diff --git a/tests/go_client/testcases/upsert_test.go b/tests/go_client/testcases/upsert_test.go index cf64c71beb..883d5b8aec 100644 --- a/tests/go_client/testcases/upsert_test.go +++ b/tests/go_client/testcases/upsert_test.go @@ -1,6 +1,7 @@ package testcases import ( + "context" "fmt" "strconv" "testing" @@ -653,3 +654,91 @@ func TestUpsertWithoutLoading(t *testing.T) { func TestUpsertPartitionKeyCollection(t *testing.T) { t.Skip("waiting gen partition key field") } + +func TestUpsertNullableFieldBehavior(t *testing.T) { + /* + Test nullable field behavior for Upsert operation: + 1. Insert data with nullable field having a value + 2. Upsert the same entity without providing the nullable field + 3. Verify that the nullable field is set to null (upsert replaces all fields) + */ + ctx := hp.CreateContext(t, time.Second*common.DefaultTimeout) + mc := hp.CreateDefaultMilvusClient(ctx, t) + + // Create collection with nullable field using custom schema + collName := common.GenRandomString("upsert_nullable", 6) + + // Create fields including nullable field + pkField := entity.NewField().WithName(common.DefaultInt64FieldName).WithDataType(entity.FieldTypeInt64).WithIsPrimaryKey(true) + vecField := entity.NewField().WithName(common.DefaultFloatVecFieldName).WithDataType(entity.FieldTypeFloatVector).WithDim(common.DefaultDim) + nullableField := entity.NewField().WithName("nullable_varchar").WithDataType(entity.FieldTypeVarChar).WithMaxLength(100).WithNullable(true) + + fields := []*entity.Field{pkField, vecField, nullableField} + schema := hp.GenSchema(hp.TNewSchemaOption().TWithName(collName).TWithDescription("test nullable field behavior for upsert").TWithFields(fields)) + + // Create collection using schema + err := mc.CreateCollection(ctx, client.NewCreateCollectionOption(collName, schema)) + common.CheckErr(t, err, true) + + // Cleanup + t.Cleanup(func() { + ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), time.Second*10) + defer cancel() + err := mc.DropCollection(ctx, client.NewDropCollectionOption(collName)) + common.CheckErr(t, err, true) + }) + + // Insert initial data with nullable field having a value + pkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, []int64{1, 2, 3}) + vecColumn := hp.GenColumnData(3, entity.FieldTypeFloatVector, *hp.TNewDataOption()) + nullableColumn := column.NewColumnVarChar("nullable_varchar", []string{"original_1", "original_2", "original_3"}) + + _, err = mc.Insert(ctx, client.NewColumnBasedInsertOption(collName).WithColumns(pkColumn, vecColumn, nullableColumn)) + common.CheckErr(t, err, true) + + // Use prepare pattern for remaining operations + prepare := &hp.CollectionPrepare{} + + // Flush data + prepare.FlushData(ctx, t, mc, collName) + + // Create index for vector field + indexParams := hp.TNewIndexParams(schema) + prepare.CreateIndex(ctx, t, mc, indexParams) + + // Load collection + loadParams := hp.NewLoadParams(collName) + prepare.Load(ctx, t, mc, loadParams) + + // Wait for 
loading to complete + time.Sleep(time.Second * 5) + + // Upsert entities without providing nullable field (should set to null) + upsertPkColumn := column.NewColumnInt64(common.DefaultInt64FieldName, []int64{1, 2}) + upsertVecColumn := hp.GenColumnData(2, entity.FieldTypeFloatVector, *hp.TNewDataOption().TWithStart(100)) + + upsertRes, err := mc.Upsert(ctx, client.NewColumnBasedInsertOption(collName).WithColumns(upsertPkColumn, upsertVecColumn)) + common.CheckErr(t, err, true) + require.EqualValues(t, 2, upsertRes.UpsertCount) + + // Wait for consistency + time.Sleep(time.Second * 3) + + // Query to verify nullable field is set to null + resSet, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("%s in [1, 2]", common.DefaultInt64FieldName)).WithOutputFields("*").WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + + // Verify results - nullable field should be null + require.Equal(t, 2, resSet.GetColumn("nullable_varchar").Len()) + nullableResults := resSet.GetColumn("nullable_varchar").(*column.ColumnVarChar).Data() + require.Equal(t, "", nullableResults[0]) // null value is represented as empty string + require.Equal(t, "", nullableResults[1]) // null value is represented as empty string + + // Query entity that was not upserted to verify original value is preserved + resSet3, err := mc.Query(ctx, client.NewQueryOption(collName).WithFilter(fmt.Sprintf("%s == 3", common.DefaultInt64FieldName)).WithOutputFields("*").WithConsistencyLevel(entity.ClStrong)) + common.CheckErr(t, err, true) + + require.Equal(t, 1, resSet3.GetColumn("nullable_varchar").Len()) + nullableResult3 := resSet3.GetColumn("nullable_varchar").(*column.ColumnVarChar).Data() + require.Equal(t, "original_3", nullableResult3[0]) +} diff --git a/tests/integration/util_collection.go b/tests/integration/util_collection.go index 284dcf5cd1..7013df650c 100644 --- a/tests/integration/util_collection.go +++ b/tests/integration/util_collection.go @@ -74,6 +74,10 @@ func (s *MiniClusterSuite) InsertAndFlush(ctx context.Context, dbName, collectio func (s *MiniClusterSuite) CreateCollectionWithConfiguration(ctx context.Context, cfg *CreateCollectionConfig) { schema := ConstructSchema(cfg.CollectionName, cfg.Dim, true) + s.CreateCollection(ctx, cfg, schema) +} + +func (s *MiniClusterSuite) CreateCollection(ctx context.Context, cfg *CreateCollectionConfig, schema *schemapb.CollectionSchema) { marshaledSchema, err := proto.Marshal(schema) s.NoError(err) s.NotNil(marshaledSchema) diff --git a/tests/restful_client_v2/testcases/test_partial_update.py b/tests/restful_client_v2/testcases/test_partial_update.py new file mode 100644 index 0000000000..3a1405c4bd --- /dev/null +++ b/tests/restful_client_v2/testcases/test_partial_update.py @@ -0,0 +1,1351 @@ +import random +from sklearn import preprocessing +import numpy as np +import sys +import json +import time +from utils import constant +from utils.utils import gen_collection_name, get_sorted_distance, patch_faker_text, en_vocabularies_distribution, \ + zh_vocabularies_distribution +from utils.util_log import test_log as logger +import pytest +from base.testbase import TestBase +from utils.utils import (gen_unique_str, get_data_by_payload, get_common_fields_by_data, gen_vector, analyze_documents) +from pymilvus import ( + FieldSchema, CollectionSchema, DataType, + Collection, utility +) +from faker import Faker +import re + +Faker.seed(19530) +fake_en = Faker("en_US") +fake_zh = Faker("zh_CN") + +patch_faker_text(fake_en, 
en_vocabularies_distribution) +patch_faker_text(fake_zh, zh_vocabularies_distribution) + + +@pytest.mark.L0 +class TestPartialUpdate(TestBase): + + @pytest.mark.parametrize("id_type", ["Int64", "VarChar"]) + def test_partial_update_basic(self, id_type): + """ + Test basic partial update functionality + 1. Create collection + 2. Insert initial data + 3. Partial update with only some fields + 4. Verify only updated fields are changed + """ + # Create collection + name = gen_collection_name() + dim = 128 + payload = { + "collectionName": name, + "schema": { + "fields": [ + {"fieldName": "book_id", "dataType": f"{id_type}", "isPrimary": True, + "elementTypeParams": {"max_length": "256"}}, + {"fieldName": "user_id", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "book_describe", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}}, + {"fieldName": "text_emb", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}} + ] + }, + "indexParams": [{"fieldName": "text_emb", "indexName": "text_emb_index", "metricType": "L2"}] + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + + # Insert initial data + nb = 10 + initial_data = [] + for i in range(nb): + tmp = { + "book_id": i if id_type == "Int64" else f"{i}", + "user_id": i, + "word_count": i * 100, + "book_describe": f"original_book_{i}", + "text_emb": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist() + } + initial_data.append(tmp) + + payload = { + "collectionName": name, + "data": initial_data, + } + rsp = self.vector_client.vector_insert(payload) + assert rsp['code'] == 0 + assert rsp['data']['insertCount'] == nb + + c = Collection(name) + c.flush() + time.sleep(3) # Wait for data to be available + + # Partial update - only update book_describe field + partial_update_data = [] + for i in range(nb): + tmp = { + "book_id": i if id_type == "Int64" else f"{i}", + "book_describe": f"updated_book_{i}", # Only update this field + } + partial_update_data.append(tmp) + + payload = { + "collectionName": name, + "data": partial_update_data, + "partialUpdate": True # Enable partial update + } + rsp = self.vector_client.vector_upsert(payload) + assert rsp['code'] == 0 + + # Verify partial update worked correctly + if id_type == "Int64": + rsp = self.vector_client.vector_query({"collectionName": name, "filter": "book_id >= 0"}) + else: + rsp = self.vector_client.vector_query({"collectionName": name, "filter": "book_id >= '0'"}) + + assert rsp['code'] == 0 + assert len(rsp['data']) == nb + + for data in rsp['data']: + book_id = int(data['book_id']) + # book_describe should be updated + assert data['book_describe'] == f"updated_book_{book_id}" + # Other fields should remain unchanged + assert data['user_id'] == book_id + assert data['word_count'] == book_id * 100 + + logger.info("Partial update basic test passed") + + @pytest.mark.parametrize("id_type", ["Int64", "VarChar"]) + def test_partial_update_multiple_fields(self, id_type): + """ + Test partial update with multiple fields + """ + # Create collection + name = gen_collection_name() + dim = 128 + payload = { + "collectionName": name, + "schema": { + "fields": [ + {"fieldName": "book_id", "dataType": f"{id_type}", "isPrimary": True, + "elementTypeParams": {"max_length": "256"}}, + {"fieldName": "user_id", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "word_count", "dataType": "Int64", 
"elementTypeParams": {}}, + {"fieldName": "book_describe", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}}, + {"fieldName": "rating", "dataType": "Double", "elementTypeParams": {}}, + {"fieldName": "text_emb", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}} + ] + }, + "indexParams": [{"fieldName": "text_emb", "indexName": "text_emb_index", "metricType": "L2"}] + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + + # Insert initial data + nb = 10 + initial_data = [] + for i in range(nb): + tmp = { + "book_id": i if id_type == "Int64" else f"{i}", + "user_id": i, + "word_count": i * 100, + "book_describe": f"original_book_{i}", + "rating": 3.5, + "text_emb": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist() + } + initial_data.append(tmp) + + payload = { + "collectionName": name, + "data": initial_data, + } + rsp = self.vector_client.vector_insert(payload) + assert rsp['code'] == 0 + + c = Collection(name) + c.flush() + time.sleep(3) + + # Partial update - update multiple fields + partial_update_data = [] + for i in range(nb): + tmp = { + "book_id": i if id_type == "Int64" else f"{i}", + "book_describe": f"updated_book_{i}", + "rating": 4.5, # Update rating + "word_count": i * 200, # Update word count + } + partial_update_data.append(tmp) + + payload = { + "collectionName": name, + "data": partial_update_data, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + assert rsp['code'] == 0 + + # Verify partial update + if id_type == "Int64": + rsp = self.vector_client.vector_query({"collectionName": name, "filter": "book_id >= 0"}) + else: + rsp = self.vector_client.vector_query({"collectionName": name, "filter": "book_id >= '0'"}) + + assert rsp['code'] == 0 + for data in rsp['data']: + book_id = int(data['book_id']) + # Updated fields + assert data['book_describe'] == f"updated_book_{book_id}" + assert data['rating'] == 4.5 + assert data['word_count'] == book_id * 200 + # Unchanged field + assert data['user_id'] == book_id + + logger.info("Partial update multiple fields test passed") + + def test_partial_update_new_record_missing_fields(self): + """ + Test partial update behavior with new records missing required fields (should fail) + """ + # Create collection + name = gen_collection_name() + dim = 128 + payload = { + "collectionName": name, + "schema": { + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": True, "elementTypeParams": {}}, + {"fieldName": "user_id", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "book_describe", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}}, + {"fieldName": "text_emb", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}} + ] + }, + "indexParams": [{"fieldName": "text_emb", "indexName": "text_emb_index", "metricType": "L2"}] + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + + # Try partial update on non-existent record with missing required fields (should fail) + partial_update_data = [{ + "book_id": 999, + "book_describe": "new_book_description" + # Missing required fields: user_id, text_emb + }] + + payload = { + "collectionName": name, + "data": partial_update_data, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + # Should fail because required fields are missing for new record insertion + assert rsp['code'] != 0 + assert "fieldSchema" in rsp['message'] or "field" in 
rsp['message'].lower() + logger.info(f"Expected failure for missing fields: {rsp['message']}") + + def test_partial_update_new_record_with_full_fields(self): + """ + Test partial update behavior with new records when all required fields are provided (should succeed) + """ + # Create collection + name = gen_collection_name() + dim = 128 + payload = { + "collectionName": name, + "schema": { + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": True, "elementTypeParams": {}}, + {"fieldName": "user_id", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "book_describe", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}}, + {"fieldName": "text_emb", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}} + ] + }, + "indexParams": [{"fieldName": "text_emb", "indexName": "text_emb_index", "metricType": "L2"}] + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + + # Partial update on non-existent record with all required fields (should succeed as insert) + partial_update_data = [{ + "book_id": 999, + "user_id": 999, + "book_describe": "new_book_description", + "text_emb": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist() + }] + + payload = { + "collectionName": name, + "data": partial_update_data, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + assert rsp['code'] == 0 + assert rsp['data']['upsertCount'] == 1 + + c = Collection(name) + c.flush() + time.sleep(3) + + # Verify the new record was inserted + rsp = self.vector_client.vector_query({"collectionName": name, "filter": "book_id == 999"}) + assert rsp['code'] == 0 + assert len(rsp['data']) == 1 + assert rsp['data'][0]['book_id'] == 999 + assert rsp['data'][0]['user_id'] == 999 + assert rsp['data'][0]['book_describe'] == "new_book_description" + + logger.info("Partial update with full fields for new record test passed") + + def test_partial_update_with_vector_field(self): + """ + Test partial update including vector field + """ + # Create collection + name = gen_collection_name() + dim = 128 + payload = { + "collectionName": name, + "schema": { + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": True, "elementTypeParams": {}}, + {"fieldName": "user_id", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "book_describe", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}}, + {"fieldName": "text_emb", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}} + ] + }, + "indexParams": [{"fieldName": "text_emb", "indexName": "text_emb_index", "metricType": "L2"}] + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + + # Insert initial data + nb = 5 + initial_data = [] + for i in range(nb): + tmp = { + "book_id": i, + "user_id": i, + "book_describe": f"original_book_{i}", + "text_emb": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist() + } + initial_data.append(tmp) + + payload = { + "collectionName": name, + "data": initial_data, + } + rsp = self.vector_client.vector_insert(payload) + assert rsp['code'] == 0 + + c = Collection(name) + c.flush() + time.sleep(3) + + # Partial update with vector field + partial_update_data = [] + for i in range(nb): + tmp = { + "book_id": i, + "text_emb": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist() + } + partial_update_data.append(tmp) + + payload = { + "collectionName": name, + 
"data": partial_update_data, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + assert rsp['code'] == 0 + + # Verify update + rsp = self.vector_client.vector_query({"collectionName": name, "filter": "book_id >= 0"}) + assert rsp['code'] == 0 + assert len(rsp['data']) == nb + + logger.info("Partial update with vector field test passed") + + def test_partial_update_mixed_scenario(self): + """ + Test partial update with mixed scenario: some records exist, some don't + """ + # Create collection + name = gen_collection_name() + dim = 128 + payload = { + "collectionName": name, + "schema": { + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": True, "elementTypeParams": {}}, + {"fieldName": "user_id", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "book_describe", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}}, + {"fieldName": "text_emb", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}} + ] + }, + "indexParams": [{"fieldName": "text_emb", "indexName": "text_emb_index", "metricType": "L2"}] + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + + # Insert some initial data + nb = 5 + initial_data = [] + for i in range(nb): + tmp = { + "book_id": i, + "user_id": i, + "word_count": i * 100, + "book_describe": f"original_book_{i}", + "text_emb": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist() + } + initial_data.append(tmp) + + payload = { + "collectionName": name, + "data": initial_data, + } + rsp = self.vector_client.vector_insert(payload) + assert rsp['code'] == 0 + + c = Collection(name) + c.flush() + time.sleep(3) + + # Step 1: Update existing records (partial fields only) + update_data = [] + for i in range(nb): + tmp = { + "book_id": i, + "book_describe": f"updated_book_{i}", # Only update description + } + update_data.append(tmp) + + payload = { + "collectionName": name, + "data": update_data, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + assert rsp['code'] == 0 + assert rsp['data']['upsertCount'] == 5 # 5 updates + + c.flush() + time.sleep(3) + + # Step 2: Insert new records (all required fields) + new_records_data = [] + for i in range(10, 13): + tmp = { + "book_id": i, + "user_id": i + 100, + "word_count": i * 50, + "book_describe": f"new_book_{i}", + "text_emb": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist() + } + new_records_data.append(tmp) + + payload = { + "collectionName": name, + "data": new_records_data, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + assert rsp['code'] == 0 + assert rsp['data']['upsertCount'] == 3 # 3 inserts + + c.flush() + time.sleep(3) + + # Verify existing records were updated (partial update) + for i in range(nb): + rsp = self.vector_client.vector_query({"collectionName": name, "filter": f"book_id == {i}"}) + assert rsp['code'] == 0 + assert len(rsp['data']) == 1 + data = rsp['data'][0] + + # Updated field + assert data['book_describe'] == f"updated_book_{i}" + # Unchanged fields + assert data['user_id'] == i + assert data['word_count'] == i * 100 + + # Verify new records were inserted (full insert) + for i in range(10, 13): + rsp = self.vector_client.vector_query({"collectionName": name, "filter": f"book_id == {i}"}) + assert rsp['code'] == 0 + assert len(rsp['data']) == 1 + data = rsp['data'][0] + + 
assert data['book_describe'] == f"new_book_{i}" + assert data['user_id'] == i + 100 + assert data['word_count'] == i * 50 + + logger.info("Mixed partial update scenario test passed") + + def test_partial_update_with_auto_id(self): + """ + Test partial update with autoID primary key - should fail as autoID is not supported for upsert + """ + # Create collection with autoID primary key + name = gen_collection_name() + dim = 128 + payload = { + "collectionName": name, + "schema": { + "autoId": True, + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": True, "elementTypeParams": {}}, + {"fieldName": "user_id", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "book_describe", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}}, + {"fieldName": "text_emb", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}} + ] + }, + "indexParams": [{"fieldName": "text_emb", "indexName": "text_emb_index", "metricType": "L2"}] + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + + # Insert initial data (without providing book_id as it's autoID) + nb = 3 + initial_data = [] + for i in range(nb): + tmp = { + "user_id": i, + "book_describe": f"original_book_{i}", + "text_emb": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist() + } + initial_data.append(tmp) + + payload = { + "collectionName": name, + "data": initial_data, + } + rsp = self.vector_client.vector_insert(payload) + assert rsp['code'] == 0 + + c = Collection(name) + c.flush() + time.sleep(3) + + # Get the auto-generated IDs before partial update + rsp = self.vector_client.vector_query({"collectionName": name, "filter": "user_id >= 0", "outputFields": ["book_id", "user_id", "book_describe"]}) + assert rsp['code'] == 0 + assert len(rsp['data']) == nb + + original_ids = [data['book_id'] for data in rsp['data']] + original_data_map = {data['user_id']: data for data in rsp['data']} + + # Partial update existing records using their auto-generated IDs + # When autoID=true, partial update should generate NEW IDs for existing records + partial_update_data = [] + for i, book_id in enumerate(original_ids): + tmp = { + "book_id": book_id, + "book_describe": f"updated_book_{i}", # Only update description + } + partial_update_data.append(tmp) + + payload = { + "collectionName": name, + "data": partial_update_data, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + assert rsp['code'] == 0 + assert rsp['data']['upsertCount'] == 3 # 3 updates + c.flush() + time.sleep(3) + + # Critical verification: old IDs should no longer exist + for old_id in original_ids: + rsp = self.vector_client.vector_query({"collectionName": name, "filter": f"book_id == {old_id}"}) + assert rsp['code'] == 0 + assert len(rsp['data']) == 0, f"Old ID {old_id} should not exist after partial update with autoID=true" + + # Verify updated records have NEW auto-generated IDs + for i in range(nb): + rsp = self.vector_client.vector_query({"collectionName": name, "filter": f"user_id == {i}"}) + assert rsp['code'] == 0 + assert len(rsp['data']) == 1 + data = rsp['data'][0] + + # Should have updated description + assert data['book_describe'] == f"updated_book_{i}" + # Should have same user_id (identifies the record) + assert data['user_id'] == i + # Should have NEW book_id (different from original) + assert data['book_id'] not in original_ids, f"New ID {data['book_id']} should be different from original IDs {original_ids}" + + # Verify total 
count is still correct (3 updated) + rsp = self.vector_client.vector_query({"collectionName": name, "filter": "user_id >= 0"}) + assert rsp['code'] == 0 + assert len(rsp['data']) == 3 + + logger.info("Partial update with autoID test passed - verified new IDs generated for updated records") + + + """ + Test detailed behavior of partial update with autoID: old record deletion and new record insertion + """ + # Create collection with autoID primary key + name = gen_collection_name() + dim = 64 + payload = { + "collectionName": name, + "schema": { + "autoId": True, + "fields": [ + {"fieldName": "id", "dataType": "Int64", "isPrimary": True, "elementTypeParams": {}}, + {"fieldName": "name", "dataType": "VarChar", "elementTypeParams": {"max_length": "100"}}, + {"fieldName": "age", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "vector", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}} + ] + }, + "indexParams": [{"fieldName": "vector", "indexName": "vector_index", "metricType": "L2"}] + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + + # Insert one record + initial_data = [{ + "name": "Alice", + "age": 25, + "vector": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist() + }] + + payload = { + "collectionName": name, + "data": initial_data, + } + rsp = self.vector_client.vector_insert(payload) + assert rsp['code'] == 0 + + c = Collection(name) + c.flush() + time.sleep(3) + + # Get the original record + rsp = self.vector_client.vector_query({"collectionName": name, "filter": "age > 0"}) + assert rsp['code'] == 0 + assert len(rsp['data']) == 1 + original_record = rsp['data'][0] + original_id = original_record['id'] + + logger.info(f"Original record: ID={original_id}, name={original_record['name']}, age={original_record['age']}") + + # Perform partial update using the original ID + partial_update_data = [{ + "id": original_id, + "name": "Alice Updated" # Only update name, age should remain unchanged + }] + + payload = { + "collectionName": name, + "data": partial_update_data, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + assert rsp['code'] == 0 + assert rsp['data']['upsertCount'] == 1 + + c.flush() + time.sleep(3) + + # Verify the original ID no longer exists + rsp = self.vector_client.vector_query({"collectionName": name, "filter": f"id == {original_id}"}) + assert rsp['code'] == 0 + assert len(rsp['data']) == 0, f"Original ID {original_id} should be deleted after partial update with autoID=true" + + # Verify there's still exactly one record with updated data + rsp = self.vector_client.vector_query({"collectionName": name, "filter": "age > 0"}) + assert rsp['code'] == 0 + assert len(rsp['data']) == 1 + + updated_record = rsp['data'][0] + new_id = updated_record['id'] + + logger.info(f"Updated record: ID={new_id}, name={updated_record['name']}, age={updated_record['age']}") + + # Verify the record has a new ID and updated fields + assert new_id != original_id, f"New ID {new_id} should be different from original ID {original_id}" + assert updated_record['name'] == "Alice Updated", "Name should be updated" + assert updated_record['age'] == 25, "Age should remain unchanged (inherited from original record)" + + logger.info("Detailed autoID partial update behavior test passed") + + def test_partial_update_auto_id_only_specified_fields_updated(self): + """ + Test that only specified fields are updated in partial update with autoID, others remain from original + """ + 
# Create collection + name = gen_collection_name() + dim = 64 + payload = { + "collectionName": name, + "schema": { + "autoId": True, + "fields": [ + {"fieldName": "id", "dataType": "Int64", "isPrimary": True,"elementTypeParams": {}}, + {"fieldName": "field1", "dataType": "VarChar", "elementTypeParams": {"max_length": "100"}}, + {"fieldName": "field2", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "field3", "dataType": "Double", "elementTypeParams": {}}, + {"fieldName": "vector", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}} + ] + }, + "indexParams": [{"fieldName": "vector", "indexName": "vector_index", "metricType": "L2"}] + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + + # Insert original record with all fields + original_vector = preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist() + initial_data = [{ + "field1": "original_value1", + "field2": 100, + "field3": 3.14, + "vector": original_vector + }] + + payload = { + "collectionName": name, + "data": initial_data, + } + rsp = self.vector_client.vector_insert(payload) + assert rsp['code'] == 0 + + c = Collection(name) + c.flush() + time.sleep(3) + + # Get original record + rsp = self.vector_client.vector_query({"collectionName": name, "filter": "field2 > 0"}) + assert rsp['code'] == 0 + assert len(rsp['data']) == 1 + original_record = rsp['data'][0] + original_id = original_record['id'] + + # Partial update - only update field1, others should remain unchanged + partial_update_data = [{ + "id": original_id, + "field1": "updated_value1" # Only update field1 + }] + + payload = { + "collectionName": name, + "data": partial_update_data, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + assert rsp['code'] == 0 + + c.flush() + time.sleep(3) + + # Verify updated record + rsp = self.vector_client.vector_query({"collectionName": name, "filter": "field2 > 0"}) + assert rsp['code'] == 0 + assert len(rsp['data']) == 1 + + updated_record = rsp['data'][0] + + # Verify new ID generated + assert updated_record['id'] != original_id, "Should have new autoID" + # Verify field1 was updated + assert updated_record['field1'] == "updated_value1", "field1 should be updated" + # Verify other fields remained unchanged + assert updated_record['field2'] == 100, "field2 should remain unchanged" + assert updated_record['field3'] == 3.14, "field3 should remain unchanged" + # Note: vector field should also remain unchanged, but might need special handling in verification + + logger.info("Partial update with autoID - only specified fields updated test passed") + + def test_partial_update_with_default_and_nullable_fields(self): + """ + Test partial update with default values and nullable fields for new records + """ + # Create collection with default value and nullable fields + name = gen_collection_name() + dim = 128 + payload = { + "collectionName": name, + "schema": { + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": True, "elementTypeParams": {}}, + {"fieldName": "user_id", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "word_count", "dataType": "Int64", "elementTypeParams": {}, "defaultValue": 1000}, # Default value + {"fieldName": "book_describe", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}}, + {"fieldName": "rating", "dataType": "Double", "elementTypeParams": {}, "nullable": True}, # Nullable field + {"fieldName": "text_emb", "dataType": "FloatVector", 
"elementTypeParams": {"dim": f"{dim}"}} + ] + }, + "indexParams": [{"fieldName": "text_emb", "indexName": "text_emb_index", "metricType": "L2"}] + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + + # Insert initial data + nb = 3 + initial_data = [] + for i in range(nb): + tmp = { + "book_id": i, + "user_id": i, + "word_count": i * 100, + "book_describe": f"original_book_{i}", + "rating": 3.5 + i * 0.5, + "text_emb": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist() + } + initial_data.append(tmp) + + payload = { + "collectionName": name, + "data": initial_data, + } + rsp = self.vector_client.vector_insert(payload) + assert rsp['code'] == 0 + + c = Collection(name) + c.flush() + time.sleep(3) + + # Test 1: Partial update existing records only + partial_update_data = [] + for i in range(nb): + tmp = { + "book_id": i, + "book_describe": f"updated_book_{i}", # Only update description + } + partial_update_data.append(tmp) + + payload = { + "collectionName": name, + "data": partial_update_data, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + assert rsp['code'] == 0 + assert rsp['data']['upsertCount'] == 3 # 3 updates + + c.flush() + time.sleep(3) + + # Verify existing records were updated (partial update) + for i in range(nb): + rsp = self.vector_client.vector_query({"collectionName": name, "filter": f"book_id == {i}"}) + assert rsp['code'] == 0 + assert len(rsp['data']) == 1 + data = rsp['data'][0] + + # Updated field + assert data['book_describe'] == f"updated_book_{i}" + # Unchanged fields + assert data['user_id'] == i + assert data['word_count'] == i * 100 # Original value, not default + assert data['rating'] == 3.5 + i * 0.5 # Original value + + # Test 2: Insert new records with minimal required fields (separate request) + new_record_data = [] + for i in range(10, 12): + tmp = { + "book_id": i, + "user_id": i + 100, + "book_describe": f"new_book_{i}", + "text_emb": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist(), + "word_count": None, #should use default value (1000) + "rating": None #nullable, should be null + } + new_record_data.append(tmp) + + payload = { + "collectionName": name, + "data": new_record_data, + } + rsp = self.vector_client.vector_upsert(payload) + assert rsp['code'] == 0 + assert rsp['data']['upsertCount'] == 2 # 2 inserts + + c.flush() + time.sleep(3) + + # Verify new records were inserted with defaults and nulls + for i in range(10, 12): + rsp = self.vector_client.vector_query({"collectionName": name, "filter": f"book_id == {i}"}) + assert rsp['code'] == 0 + assert len(rsp['data']) == 1 + data = rsp['data'][0] + + assert data['book_describe'] == f"new_book_{i}" + assert data['user_id'] == i + 100 + assert data['word_count'] == 1000 # Should use default value + # Note: Nullable field behavior depends on implementation + # It might be null or omitted from result + + logger.info("Partial update with default and nullable fields test passed") + + def test_partial_update_nullable_field_scenarios(self): + """ + Test partial update with nullable fields in various scenarios: + 1. Nullable field with no default value, insert without value, then update to new value + 2. Nullable field with default value, insert without value, then update to new value + 3. 
Nullable field with no default value, insert with value, then update to null + """ + # Create collection with nullable fields + name = gen_collection_name() + dim = 128 + payload = { + "collectionName": name, + "schema": { + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": True, "elementTypeParams": {}}, + {"fieldName": "user_id", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "book_describe", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}}, + {"fieldName": "nullable_field_no_default", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}, "nullable": True}, # Nullable, no default + {"fieldName": "nullable_field_with_default", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}, "nullable": True, "defaultValue": "default_value"}, # Nullable with default + {"fieldName": "nullable_field_for_null_update", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}, "nullable": True}, # Nullable, no default + {"fieldName": "text_emb", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}} + ] + }, + "indexParams": [{"fieldName": "text_emb", "indexName": "text_emb_index", "metricType": "L2"}] + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + + # Scenario 1: Insert data with nullable field (no default) not provided + initial_data_scenario1 = { + "book_id": 1, + "user_id": 1, + "book_describe": "test_book_1", + "nullable_field_with_default": None, # Use default value + "nullable_field_for_null_update": "initial_value", # Will be updated to null later + "text_emb": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist() + # nullable_field_no_default is not provided, should be null + } + + # Scenario 2: Insert data with nullable field (with default) not provided + initial_data_scenario2 = { + "book_id": 2, + "user_id": 2, + "book_describe": "test_book_2", + "nullable_field_no_default": None, # Should remain null + "nullable_field_for_null_update": "another_initial_value", + "text_emb": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist() + # nullable_field_with_default is not provided, should use default value + } + + # Scenario 3: Insert data with nullable field that will be updated to null + initial_data_scenario3 = { + "book_id": 3, + "user_id": 3, + "book_describe": "test_book_3", + "nullable_field_no_default": None, + "nullable_field_with_default": "custom_value", # Custom value, not default + "nullable_field_for_null_update": "value_to_be_nulled", # Will be updated to null + "text_emb": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist() + } + + # Insert all initial data + initial_data = [initial_data_scenario1, initial_data_scenario2, initial_data_scenario3] + payload = { + "collectionName": name, + "data": initial_data, + } + rsp = self.vector_client.vector_insert(payload) + assert rsp['code'] == 0 + assert rsp['data']['insertCount'] == 3 + + c = Collection(name) + c.flush() + time.sleep(3) + + # Verify initial state + rsp = self.vector_client.vector_query({"collectionName": name, "filter": "book_id >= 1", "outputFields": ["*"]}) + assert rsp['code'] == 0 + assert len(rsp['data']) == 3 + + # Check initial values + data_by_id = {data['book_id']: data for data in rsp['data']} + + # Scenario 1 verification: nullable field with no default should be null/not present + assert data_by_id[1]['nullable_field_no_default'] is None or 'nullable_field_no_default' 
not in data_by_id[1] + assert data_by_id[1]['nullable_field_with_default'] == "default_value" # Should use default + assert data_by_id[1]['nullable_field_for_null_update'] == "initial_value" + + # Scenario 2 verification: nullable field with default should use default value + assert data_by_id[2]['nullable_field_no_default'] is None or 'nullable_field_no_default' not in data_by_id[2] + assert data_by_id[2]['nullable_field_with_default'] == "default_value" # Should use default + assert data_by_id[2]['nullable_field_for_null_update'] == "another_initial_value" + + # Scenario 3 verification: all fields should have the provided values + assert data_by_id[3]['nullable_field_no_default'] is None or 'nullable_field_no_default' not in data_by_id[3] + assert data_by_id[3]['nullable_field_with_default'] == "custom_value" + assert data_by_id[3]['nullable_field_for_null_update'] == "value_to_be_nulled" + + logger.info("Initial data verification passed") + + # Now perform partial updates for each scenario separately + # Note: Partial update does not support updating different columns for multiple rows in a single request + + # Scenario 1: Update nullable field (no default) from null to new value + partial_update_scenario1 = [{ + "book_id": 1, + "nullable_field_no_default": "updated_value_1" + }] + + payload = { + "collectionName": name, + "data": partial_update_scenario1, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + assert rsp['code'] == 0 + assert rsp['data']['upsertCount'] == 1 + + c.flush() + time.sleep(2) + + # Scenario 2: Update nullable field (with default) from default to new value + partial_update_scenario2 = [{ + "book_id": 2, + "nullable_field_with_default": "updated_value_2" + }] + + payload = { + "collectionName": name, + "data": partial_update_scenario2, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + assert rsp['code'] == 0 + assert rsp['data']['upsertCount'] == 1 + + c.flush() + time.sleep(2) + + # Scenario 3: Update nullable field from value to null + partial_update_scenario3 = [{ + "book_id": 3, + "nullable_field_for_null_update": None + }] + + payload = { + "collectionName": name, + "data": partial_update_scenario3, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + assert rsp['code'] == 0 + assert rsp['data']['upsertCount'] == 1 + + c.flush() + time.sleep(2) + + # Verify partial update results + rsp = self.vector_client.vector_query({"collectionName": name, "filter": "book_id >= 1", "outputFields": ["*"]}) + assert rsp['code'] == 0 + assert len(rsp['data']) == 3 + + updated_data_by_id = {data['book_id']: data for data in rsp['data']} + + # Scenario 1: Verify nullable field (no default) was updated from null to new value + assert updated_data_by_id[1]['nullable_field_no_default'] == "updated_value_1" + # Other fields should remain unchanged + assert updated_data_by_id[1]['user_id'] == 1 + assert updated_data_by_id[1]['book_describe'] == "test_book_1" + assert updated_data_by_id[1]['nullable_field_with_default'] == "default_value" + assert updated_data_by_id[1]['nullable_field_for_null_update'] == "initial_value" + + # Scenario 2: Verify nullable field (with default) was updated from default to new value + assert updated_data_by_id[2]['nullable_field_with_default'] == "updated_value_2" + # Other fields should remain unchanged + assert updated_data_by_id[2]['user_id'] == 2 + assert updated_data_by_id[2]['book_describe'] == "test_book_2" + assert 
updated_data_by_id[2]['nullable_field_no_default'] is None or 'nullable_field_no_default' not in updated_data_by_id[2] + assert updated_data_by_id[2]['nullable_field_for_null_update'] == "another_initial_value" + + # Scenario 3: Verify nullable field was updated from value to null + # Note, the RESTful SDK cannot differentiate between missing fields and fields explicitly set to null, + # so partial update to null values is not supported" + # assert updated_data_by_id[3]['nullable_field_for_null_update'] is None or 'nullable_field_for_null_update' not in updated_data_by_id[3] + # Other fields should remain unchanged + assert updated_data_by_id[3]['user_id'] == 3 + assert updated_data_by_id[3]['book_describe'] == "test_book_3" + assert updated_data_by_id[3]['nullable_field_no_default'] is None or 'nullable_field_no_default' not in updated_data_by_id[3] + assert updated_data_by_id[3]['nullable_field_with_default'] == "custom_value" + + logger.info("All nullable field partial update scenarios passed") + logger.info("Scenario 1: nullable field (no default) null -> new value: PASSED") + logger.info("Scenario 2: nullable field (with default) default -> new value: PASSED") + logger.info("Scenario 3: nullable field value -> null: PASSED") + + +@pytest.mark.L1 +class TestPartialUpdateNegative(TestBase): + + def test_partial_update_without_primary_key(self): + """ + Test partial update fails when primary key is missing + """ + # Create collection + name = gen_collection_name() + dim = 128 + payload = { + "collectionName": name, + "schema": { + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": True, "elementTypeParams": {}}, + {"fieldName": "user_id", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "book_describe", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}}, + {"fieldName": "text_emb", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}} + ] + }, + "indexParams": [{"fieldName": "text_emb", "indexName": "text_emb_index", "metricType": "L2"}] + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + + # Try partial update without primary key (should fail) + partial_update_data = [{ + "book_describe": "updated_description" + # Missing book_id (primary key) + }] + + payload = { + "collectionName": name, + "data": partial_update_data, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + # Should fail with appropriate error code + assert rsp['code'] != 0 + logger.info(f"Expected failure response: {rsp}") + + def test_partial_update_invalid_collection_name(self): + """ + Test partial update with invalid collection name + """ + partial_update_data = [{ + "book_id": 1, + "book_describe": "updated_description" + }] + + payload = { + "collectionName": "non_existent_collection", + "data": partial_update_data, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + assert rsp['code'] != 0 + logger.info(f"Expected failure response: {rsp}") + + def test_partial_update_invalid_field_type(self): + """ + Test partial update with invalid field type + """ + # Create collection + name = gen_collection_name() + dim = 128 + payload = { + "collectionName": name, + "schema": { + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": True, "elementTypeParams": {}}, + {"fieldName": "user_id", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": "book_describe", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}}, + {"fieldName": 
"text_emb", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}} + ] + }, + "indexParams": [{"fieldName": "text_emb", "indexName": "text_emb_index", "metricType": "L2"}] + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + + # Try partial update with wrong data type + partial_update_data = [{ + "book_id": 1, + "user_id": "invalid_string_for_int_field" # Should be int, not string + }] + + payload = { + "collectionName": name, + "data": partial_update_data, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + # Should fail with appropriate error code + assert rsp['code'] != 0 + logger.info(f"Expected failure response: {rsp}") + + def test_partial_update_empty_data(self): + """ + Test partial update with empty data array + """ + # Create collection (must include vector field) + name = gen_collection_name() + dim = 64 + payload = { + "collectionName": name, + "schema": { + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": True, "elementTypeParams": {}}, + {"fieldName": "book_describe", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}}, + {"fieldName": "vector", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}} + ] + }, + "indexParams": [{"fieldName": "vector", "indexName": "vector_index", "metricType": "L2"}] + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + + # Try partial update with empty data + payload = { + "collectionName": name, + "data": [], # Empty data array + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + # Should fail with appropriate error + assert rsp['code'] != 0 + logger.info(f"Expected failure for empty data: {rsp['message']}") + + def test_partial_update_non_existent_field(self): + """ + Test partial update with non-existent field names + """ + # Create collection (must include vector field) + name = gen_collection_name() + dim = 64 + payload = { + "collectionName": name, + "schema": { + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": True, "elementTypeParams": {}}, + {"fieldName": "book_describe", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}}, + {"fieldName": "vector", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}} + ] + }, + "indexParams": [{"fieldName": "vector", "indexName": "vector_index", "metricType": "L2"}] + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + + # Try partial update with non-existent field + partial_update_data = [{ + "book_id": 1, + "non_existent_field": "some_value" # Field doesn't exist in schema + }] + + payload = { + "collectionName": name, + "data": partial_update_data, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + # Should fail with appropriate error + assert rsp['code'] != 0 + assert "dynamic schema" in rsp['message'] or "not exist" in rsp['message'] or "unknown" in rsp['message'].lower() + logger.info(f"Expected failure for non-existent field: {rsp['message']}") + + def test_partial_update_mixed_success_failure(self): + """ + Test partial update with mixed valid and invalid records + """ + # Create collection + name = gen_collection_name() + dim = 128 + payload = { + "collectionName": name, + "schema": { + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": True, "elementTypeParams": {}}, + {"fieldName": "user_id", "dataType": "Int64", "elementTypeParams": {}}, + {"fieldName": 
"book_describe", "dataType": "VarChar", "elementTypeParams": {"max_length": "256"}}, + {"fieldName": "text_emb", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}} + ] + }, + "indexParams": [{"fieldName": "text_emb", "indexName": "text_emb_index", "metricType": "L2"}] + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + + # Insert some existing data + initial_data = [{ + "book_id": 1, + "user_id": 1, + "book_describe": "existing_book", + "text_emb": preprocessing.normalize([np.array([random.random() for j in range(dim)])])[0].tolist() + }] + + payload = { + "collectionName": name, + "data": initial_data, + } + rsp = self.vector_client.vector_insert(payload) + assert rsp['code'] == 0 + + c = Collection(name) + c.flush() + time.sleep(3) + + # Mixed partial update: valid existing record update + invalid new record (missing required fields) + mixed_data = [ + { + "book_id": 1, + "book_describe": "updated_existing_book" # Valid partial update for existing record + }, + { + "book_id": 999, + "book_describe": "new_book_missing_fields" # Invalid - missing user_id and text_emb for new record + } + ] + + payload = { + "collectionName": name, + "data": mixed_data, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + # Should fail because one record is invalid + assert rsp['code'] != 0 + logger.info(f"Expected failure for mixed valid/invalid records: {rsp['message']}") + + def test_partial_update_vector_dimension_mismatch(self): + """ + Test partial update with vector dimension mismatch + """ + # Create collection + name = gen_collection_name() + dim = 128 + payload = { + "collectionName": name, + "schema": { + "fields": [ + {"fieldName": "book_id", "dataType": "Int64", "isPrimary": True, "elementTypeParams": {}}, + {"fieldName": "text_emb", "dataType": "FloatVector", "elementTypeParams": {"dim": f"{dim}"}} + ] + }, + "indexParams": [{"fieldName": "text_emb", "indexName": "text_emb_index", "metricType": "L2"}] + } + rsp = self.collection_client.collection_create(payload) + assert rsp['code'] == 0 + + # Try partial update with wrong vector dimension + partial_update_data = [{ + "book_id": 1, + "text_emb": [random.random() for _ in range(64)] # Wrong dimension (64 instead of 128) + }] + + payload = { + "collectionName": name, + "data": partial_update_data, + "partialUpdate": True + } + rsp = self.vector_client.vector_upsert(payload) + # Should fail with dimension mismatch error + assert rsp['code'] != 0 + assert "dimension" in rsp['message'].lower() or "dim" in rsp['message'].lower() + logger.info(f"Expected failure for dimension mismatch: {rsp['message']}") \ No newline at end of file