milvus/internal/proxy/task_upsert_test.go
congqixia a592cfc8b4
enhance: [2.6] extract shard client logic into dedicated package (#45018) (#45031)
Cherry-pick from master
pr: #45018 #45030
Related to #44761

Refactor proxy shard client management by creating a new
internal/proxy/shardclient package. This improves code organization and
modularity by:

- Moving load balancing logic (LookAsideBalancer, RoundRobinBalancer) to
shardclient package
- Extracting shard client manager and related interfaces into separate
package
- Relocating shard leader management and client lifecycle code
- Adding package documentation (README.md, OWNERS)
- Updating proxy code to use the new shardclient package interfaces

This change makes the shard client functionality more maintainable and
better encapsulated, reducing coupling in the proxy layer.

Also consolidates the proxy package's mockery generation to use a
centralized `.mockery.yaml` configuration file, aligning with the
pattern used by other packages like querycoordv2.

Changes
- **Makefile**: Replace multiple individual mockery commands with a
single config-based invocation for `generate-mockery-proxy` target
- **internal/proxy/.mockery.yaml**: Add mockery configuration defining
all mock interfaces for proxy and proxy/shardclient packages
- **Mock files**: Regenerate mocks using the new configuration:
  - `mock_cache.go`: Clean up by removing unused interface methods
  (credential, shard cache, policy methods)
  - `shardclient/mock_lb_balancer.go`: Update type comments
  (nodeInfo → NodeInfo)
  - `shardclient/mock_lb_policy.go`: Update formatting
  - `shardclient/mock_shardclient_manager.go`: Fix parameter naming
  consistency (nodeInfo1 → nodeInfo)
- **task_search_test.go**: Remove obsolete mock expectations for
deprecated cache methods

Benefits
- Centralized mockery configuration for easier maintenance
- Consistent with other packages (querycoordv2, etc.)
- Cleaner mock interfaces by removing unused methods
- Better type consistency in generated mocks

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
2025-10-22 16:06:06 +08:00

1359 lines
44 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"context"
"testing"
"github.com/bytedance/mockey"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
grpcmixcoordclient "github.com/milvus-io/milvus/internal/distributed/mixcoord/client"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/proxy/shardclient"
"github.com/milvus-io/milvus/internal/util/function/embedding"
"github.com/milvus-io/milvus/internal/util/segcore"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/testutils"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// TestUpsertTask_CheckAligned verifies InsertMsg.CheckAligned on upsert
// payloads: an empty request is accepted, a fully aligned ten-column request
// is accepted, and shrinking or growing any single column's row count is
// rejected until that column is restored to NumRows rows.
func TestUpsertTask_CheckAligned(t *testing.T) {
	var err error

	// NumRows is 0 and no field data is passed.
	case1 := upsertTask{
		req: &milvuspb.UpsertRequest{
			NumRows: 0,
		},
		upsertMsg: &msgstream.UpsertMsg{
			InsertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{},
			},
		},
	}
	case1.upsertMsg.InsertMsg.InsertRequest = &msgpb.InsertRequest{
		Base: commonpbutil.NewMsgBase(
			commonpbutil.WithMsgType(commonpb.MsgType_Insert),
		),
		CollectionName: case1.req.CollectionName,
		PartitionName:  case1.req.PartitionName,
		FieldsData:     case1.req.FieldsData,
		NumRows:        uint64(case1.req.NumRows),
		Version:        msgpb.InsertDataVersion_ColumnBased,
	}
	err = case1.upsertMsg.InsertMsg.CheckAligned()
	assert.NoError(t, err)

	// checkFieldsDataBySchema was already checked by TestUpsertTask_checkFieldsDataBySchema
	boolFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}
	int8FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int8}
	int16FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int16}
	int32FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int32}
	int64FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}
	floatFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Float}
	doubleFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Double}
	floatVectorFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector}
	binaryVectorFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_BinaryVector}
	varCharFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}
	numRows := 20
	dim := 128
	collSchema := &schemapb.CollectionSchema{
		Name:        "TestUpsertTask_checkRowNums",
		Description: "TestUpsertTask_checkRowNums",
		AutoID:      false,
		Fields: []*schemapb.FieldSchema{
			boolFieldSchema,
			int8FieldSchema,
			int16FieldSchema,
			int32FieldSchema,
			int64FieldSchema,
			floatFieldSchema,
			doubleFieldSchema,
			floatVectorFieldSchema,
			binaryVectorFieldSchema,
			varCharFieldSchema,
		},
	}
	schema := newSchemaInfo(collSchema)
	case2 := upsertTask{
		req: &milvuspb.UpsertRequest{
			NumRows:    uint32(numRows),
			FieldsData: []*schemapb.FieldData{},
		},
		rowIDs:     testutils.GenerateInt64Array(numRows),
		timestamps: testutils.GenerateUint64Array(numRows),
		schema:     schema,
		upsertMsg: &msgstream.UpsertMsg{
			InsertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{},
			},
		},
	}

	// satisfied: every column carries exactly numRows rows.
	case2.req.FieldsData = []*schemapb.FieldData{
		newScalarFieldData(boolFieldSchema, "Bool", numRows),
		newScalarFieldData(int8FieldSchema, "Int8", numRows),
		newScalarFieldData(int16FieldSchema, "Int16", numRows),
		newScalarFieldData(int32FieldSchema, "Int32", numRows),
		newScalarFieldData(int64FieldSchema, "Int64", numRows),
		newScalarFieldData(floatFieldSchema, "Float", numRows),
		newScalarFieldData(doubleFieldSchema, "Double", numRows),
		newFloatVectorFieldData("FloatVector", numRows, dim),
		newBinaryVectorFieldData("BinaryVector", numRows, dim),
		newScalarFieldData(varCharFieldSchema, "VarChar", numRows),
	}
	case2.upsertMsg.InsertMsg.InsertRequest = &msgpb.InsertRequest{
		Base: commonpbutil.NewMsgBase(
			commonpbutil.WithMsgType(commonpb.MsgType_Insert),
		),
		CollectionName: case2.req.CollectionName,
		PartitionName:  case2.req.PartitionName,
		FieldsData:     case2.req.FieldsData,
		NumRows:        uint64(case2.req.NumRows),
		RowIDs:         case2.rowIDs,
		Timestamps:     case2.timestamps,
		Version:        msgpb.InsertDataVersion_ColumnBased,
	}
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.NoError(t, err)

	// setAndCheck replaces column idx with fd, mirrors the change into the
	// insert message and returns the CheckAligned result.
	setAndCheck := func(idx int, fd *schemapb.FieldData) error {
		case2.req.FieldsData[idx] = fd
		case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
		return case2.upsertMsg.InsertMsg.CheckAligned()
	}

	// idx must match the column's position in case2.req.FieldsData above so
	// each check mutates the column of the matching type.
	columns := []struct {
		name string
		idx  int
		gen  func(rows int) *schemapb.FieldData
	}{
		{"bool", 0, func(rows int) *schemapb.FieldData { return newScalarFieldData(boolFieldSchema, "Bool", rows) }},
		{"int8", 1, func(rows int) *schemapb.FieldData { return newScalarFieldData(int8FieldSchema, "Int8", rows) }},
		{"int16", 2, func(rows int) *schemapb.FieldData { return newScalarFieldData(int16FieldSchema, "Int16", rows) }},
		{"int32", 3, func(rows int) *schemapb.FieldData { return newScalarFieldData(int32FieldSchema, "Int32", rows) }},
		{"int64", 4, func(rows int) *schemapb.FieldData { return newScalarFieldData(int64FieldSchema, "Int64", rows) }},
		{"float", 5, func(rows int) *schemapb.FieldData { return newScalarFieldData(floatFieldSchema, "Float", rows) }},
		{"double", 6, func(rows int) *schemapb.FieldData { return newScalarFieldData(doubleFieldSchema, "Double", rows) }},
		{"float vector", 7, func(rows int) *schemapb.FieldData { return newFloatVectorFieldData("FloatVector", rows, dim) }},
		{"binary vector", 8, func(rows int) *schemapb.FieldData { return newBinaryVectorFieldData("BinaryVector", rows, dim) }},
		{"varchar", 9, func(rows int) *schemapb.FieldData { return newScalarFieldData(varCharFieldSchema, "VarChar", rows) }},
	}
	for _, col := range columns {
		// fewer rows than NumRows must be rejected
		err = setAndCheck(col.idx, col.gen(numRows/2))
		assert.Error(t, err, "less %s data", col.name)
		// more rows than NumRows must be rejected
		err = setAndCheck(col.idx, col.gen(numRows*2))
		assert.Error(t, err, "more %s data", col.name)
		// restoring exactly NumRows rows must be accepted again
		err = setAndCheck(col.idx, col.gen(numRows))
		assert.NoError(t, err, "revert %s data", col.name)
	}
}
func TestUpsertTask(t *testing.T) {
t.Run("test getChannels", func(t *testing.T) {
collectionID := UniqueID(0)
collectionName := "col-0"
channels := []pChan{"mock-chan-0", "mock-chan-1"}
cache := NewMockCache(t)
cache.On("GetCollectionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(collectionID, nil)
globalMetaCache = cache
chMgr := NewMockChannelsMgr(t)
chMgr.EXPECT().getChannels(mock.Anything).Return(channels, nil)
ut := upsertTask{
ctx: context.Background(),
req: &milvuspb.UpsertRequest{
CollectionName: collectionName,
},
chMgr: chMgr,
}
err := ut.setChannels()
assert.NoError(t, err)
resChannels := ut.getChannels()
assert.ElementsMatch(t, channels, resChannels)
assert.ElementsMatch(t, channels, ut.pChannels)
})
}
// TestUpsertTaskForReplicate expects PreExecute to fail both when the
// collection info lookup errors and when the cached collection carries a
// non-empty replicateID.
func TestUpsertTaskForReplicate(t *testing.T) {
	savedCache := globalMetaCache
	t.Cleanup(func() { globalMetaCache = savedCache })

	metaCacheMock := NewMockCache(t)
	globalMetaCache = metaCacheMock
	bgCtx := context.Background()

	// newTask builds a fresh upsert task targeting "col-0".
	newTask := func() *upsertTask {
		return &upsertTask{
			ctx: bgCtx,
			req: &milvuspb.UpsertRequest{CollectionName: "col-0"},
		}
	}

	t.Run("fail to get collection info", func(t *testing.T) {
		task := newTask()
		metaCacheMock.EXPECT().
			GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
			Return(nil, errors.New("foo")).
			Once()
		assert.Error(t, task.PreExecute(bgCtx))
	})

	t.Run("replicate mode", func(t *testing.T) {
		task := newTask()
		replicated := &collectionInfo{replicateID: "local-mac"}
		metaCacheMock.EXPECT().
			GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
			Return(replicated, nil).
			Once()
		assert.Error(t, task.PreExecute(bgCtx))
	})
}
// TestUpsertTask_Function exercises the text-embedding function path of an
// upsert. genFunctionFields fills the function output field using a local
// OpenAI-style embedding server (embedding.CreateOpenAIEmbeddingServer), and
// insertPreExecute then succeeds. It fails once NumRows no longer matches the
// data.
func TestUpsertTask_Function(t *testing.T) {
	paramtable.Init()
	// Both config hooks are overridden for this test only. Restore them so the
	// mock server URL (closed when this test returns) does not leak into
	// later tests.
	origCredentialFunc := paramtable.Get().CredentialCfg.Credential.GetFunc
	origProvidersFunc := paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc
	defer func() {
		paramtable.Get().CredentialCfg.Credential.GetFunc = origCredentialFunc
		paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = origProvidersFunc
	}()
	paramtable.Get().CredentialCfg.Credential.GetFunc = func() map[string]string {
		return map[string]string{
			"mock.apikey": "mock",
		}
	}
	ts := embedding.CreateOpenAIEmbeddingServer()
	defer ts.Close()
	paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
		return map[string]string{
			"openai.url": ts.URL,
		}
	}
	data := []*schemapb.FieldData{}
	f1 := schemapb.FieldData{
		Type:      schemapb.DataType_Int64,
		FieldId:   100,
		FieldName: "id",
		IsDynamic: false,
		Field: &schemapb.FieldData_Scalars{
			Scalars: &schemapb.ScalarField{
				Data: &schemapb.ScalarField_LongData{
					LongData: &schemapb.LongArray{
						Data: []int64{0, 1},
					},
				},
			},
		},
	}
	data = append(data, &f1)
	f2 := schemapb.FieldData{
		Type:      schemapb.DataType_VarChar,
		FieldId:   101,
		FieldName: "text",
		IsDynamic: false,
		Field: &schemapb.FieldData_Scalars{
			Scalars: &schemapb.ScalarField{
				Data: &schemapb.ScalarField_StringData{
					StringData: &schemapb.StringArray{
						Data: []string{"sentence", "sentence"},
					},
				},
			},
		},
	}
	data = append(data, &f2)
	collectionName := "TestUpsertTask_function"
	schema := &schemapb.CollectionSchema{
		Name:        collectionName,
		Description: "TestUpsertTask_function",
		AutoID:      true,
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
			{
				FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: "max_length", Value: "200"},
				},
			},
			{
				FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: "dim", Value: "4"},
				},
				IsFunctionOutput: true,
			},
		},
		Functions: []*schemapb.FunctionSchema{
			{
				Name:             "test_function",
				Type:             schemapb.FunctionType_TextEmbedding,
				InputFieldIds:    []int64{101},
				InputFieldNames:  []string{"text"},
				OutputFieldIds:   []int64{102},
				OutputFieldNames: []string{"vector"},
				Params: []*commonpb.KeyValuePair{
					{Key: "provider", Value: "openai"},
					{Key: "model_name", Value: "text-embedding-ada-002"},
					{Key: "credential", Value: "mock"},
					{Key: "dim", Value: "4"},
				},
			},
		},
	}

	info := newSchemaInfo(schema)
	collectionID := UniqueID(0)
	// Restore the package-level cache so this test's mock does not leak.
	origCache := globalMetaCache
	defer func() { globalMetaCache = origCache }()
	cache := NewMockCache(t)
	globalMetaCache = cache

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	rc := mocks.NewMockRootCoordClient(t)
	rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{
		Status: merr.Status(nil),
		ID:     collectionID,
		Count:  10,
	}, nil)
	idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0)
	// Check the error before touching the allocator: on failure it may be nil,
	// and Start/Close would panic instead of reporting a test failure.
	if !assert.NoError(t, err) {
		return
	}
	idAllocator.Start()
	defer idAllocator.Close()

	task := upsertTask{
		ctx: context.Background(),
		req: &milvuspb.UpsertRequest{
			CollectionName: collectionName,
		},
		upsertMsg: &msgstream.UpsertMsg{
			InsertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{
					Base: commonpbutil.NewMsgBase(
						commonpbutil.WithMsgType(commonpb.MsgType_Insert),
					),
					CollectionName: collectionName,
					DbName:         "hooooooo",
					Version:        msgpb.InsertDataVersion_ColumnBased,
					FieldsData:     data,
					NumRows:        2,
					PartitionName:  Params.CommonCfg.DefaultPartitionName.GetValue(),
				},
			},
		},
		idAllocator: idAllocator,
		schema:      info,
		result:      &milvuspb.MutationResult{},
	}
	err = genFunctionFields(task.ctx, task.upsertMsg.InsertMsg, task.schema, task.req.GetPartialUpdate())
	assert.NoError(t, err)
	err = task.insertPreExecute(ctx)
	assert.NoError(t, err)

	// process failed: NumRows no longer matches the supplied data.
	{
		oldRows := task.upsertMsg.InsertMsg.InsertRequest.NumRows
		task.upsertMsg.InsertMsg.InsertRequest.NumRows = 10000
		err = task.insertPreExecute(ctx)
		assert.Error(t, err)
		task.upsertMsg.InsertMsg.InsertRequest.NumRows = oldRows
	}
}
// TestUpsertTaskForSchemaMismatch expects PreExecute to fail with
// merr.ErrCollectionSchemaMismatch when the task's schemaTimestamp (99) is
// behind the cached collection's updateTimestamp (100).
func TestUpsertTaskForSchemaMismatch(t *testing.T) {
	savedCache := globalMetaCache
	t.Cleanup(func() { globalMetaCache = savedCache })

	metaCacheMock := NewMockCache(t)
	globalMetaCache = metaCacheMock
	bgCtx := context.Background()

	t.Run("schema ts mismatch", func(t *testing.T) {
		task := &upsertTask{
			ctx: bgCtx,
			req: &milvuspb.UpsertRequest{
				CollectionName: "col-0",
				NumRows:        10,
			},
			schemaTimestamp: 99,
		}

		metaCacheMock.EXPECT().
			GetCollectionID(mock.Anything, mock.Anything, mock.Anything).
			Return(0, nil)
		metaCacheMock.EXPECT().
			GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
			Return(&collectionInfo{updateTimestamp: 100}, nil)
		metaCacheMock.EXPECT().
			GetDatabaseInfo(mock.Anything, mock.Anything).
			Return(&databaseInfo{dbID: 0}, nil)

		preErr := task.PreExecute(bgCtx)
		assert.Error(t, preErr)
		assert.ErrorIs(t, preErr, merr.ErrCollectionSchemaMismatch)
	})
}
// createTestUpdateTask builds an upsertTask for "test_db"."test_collection"
// (partition "_default", collectionID 1001) carrying three rows: an INT64 "id"
// column, a VARCHAR "name" column and a 128-dim FLOAT_VECTOR "vector" column.
// Its Proxy holds a zero-value grpcmixcoordclient.Client and the LB policy
// returned by shardclient.NewLBPolicyImpl(nil).
func createTestUpdateTask() *upsertTask {
	const dim = 128

	idColumn := &schemapb.FieldData{
		FieldName: "id",
		FieldId:   100,
		Type:      schemapb.DataType_Int64,
		Field: &schemapb.FieldData_Scalars{
			Scalars: &schemapb.ScalarField{
				Data: &schemapb.ScalarField_LongData{
					LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}},
				},
			},
		},
	}
	nameColumn := &schemapb.FieldData{
		FieldName: "name",
		FieldId:   102,
		Type:      schemapb.DataType_VarChar,
		Field: &schemapb.FieldData_Scalars{
			Scalars: &schemapb.ScalarField{
				Data: &schemapb.ScalarField_StringData{
					StringData: &schemapb.StringArray{Data: []string{"test1", "test2", "test3"}},
				},
			},
		},
	}
	vectorColumn := &schemapb.FieldData{
		FieldName: "vector",
		FieldId:   101,
		Type:      schemapb.DataType_FloatVector,
		Field: &schemapb.FieldData_Vectors{
			Vectors: &schemapb.VectorField{
				Dim: dim,
				Data: &schemapb.VectorField_FloatVector{
					// three rows of dim floats each
					FloatVector: &schemapb.FloatArray{Data: make([]float32, 3*dim)},
				},
			},
		},
	}

	task := &upsertTask{
		baseTask:  baseTask{},
		Condition: NewTaskCondition(context.Background()),
		req: &milvuspb.UpsertRequest{
			DbName:         "test_db",
			CollectionName: "test_collection",
			PartitionName:  "_default",
			FieldsData:     []*schemapb.FieldData{idColumn, nameColumn, vectorColumn},
			NumRows:        3,
		},
		ctx:          context.Background(),
		schema:       createTestSchema(),
		collectionID: 1001,
		node: &Proxy{
			mixCoord: &grpcmixcoordclient.Client{},
			lbPolicy: shardclient.NewLBPolicyImpl(nil),
		},
	}
	return task
}
// createTestSchema wraps a "test_collection" schema with three fields in
// newSchemaInfo: INT64 primary key "id" (100), FLOAT_VECTOR "vector" (101)
// with dim 128, and VARCHAR "name" (102).
func createTestSchema() *schemaInfo {
	pkField := &schemapb.FieldSchema{
		FieldID:      100,
		Name:         "id",
		IsPrimaryKey: true,
		DataType:     schemapb.DataType_Int64,
	}
	vectorField := &schemapb.FieldSchema{
		FieldID:  101,
		Name:     "vector",
		DataType: schemapb.DataType_FloatVector,
		TypeParams: []*commonpb.KeyValuePair{
			{Key: "dim", Value: "128"},
		},
	}
	nameField := &schemapb.FieldSchema{
		FieldID:  102,
		Name:     "name",
		DataType: schemapb.DataType_VarChar,
	}

	return newSchemaInfo(&schemapb.CollectionSchema{
		Name:   "test_collection",
		Fields: []*schemapb.FieldSchema{pkField, vectorField, nameField},
	})
}
func TestRetrieveByPKs_Success(t *testing.T) {
mockey.PatchConvey("TestRetrieveByPKs_Success", t, func() {
// Setup mocks
mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{
FieldID: 100,
Name: "id",
IsPrimaryKey: true,
DataType: schemapb.DataType_Int64,
}, nil).Build()
mockey.Mock(validatePartitionTag).Return(nil).Build()
mockey.Mock((*MetaCache).GetPartitionID).Return(int64(1002), nil).Build()
mockey.Mock(planparserv2.CreateRequeryPlan).Return(&planpb.PlanNode{}).Build()
mockey.Mock((*Proxy).query).Return(&milvuspb.QueryResults{
Status: merr.Success(),
FieldsData: []*schemapb.FieldData{
{
FieldName: "id",
FieldId: 100,
Type: schemapb.DataType_Int64,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{Data: []int64{1, 2}},
},
},
},
},
},
}, segcore.StorageCost{}, nil).Build()
globalMetaCache = &MetaCache{}
mockey.Mock(globalMetaCache.GetPartitionID).Return(int64(1002), nil).Build()
// Execute test
task := createTestUpdateTask()
task.partitionKeyMode = false
task.upsertMsg = &msgstream.UpsertMsg{
InsertMsg: &msgstream.InsertMsg{
InsertRequest: &msgpb.InsertRequest{
PartitionName: "_default",
},
},
DeleteMsg: &msgstream.DeleteMsg{
DeleteRequest: &msgpb.DeleteRequest{
PartitionName: "_default",
},
},
}
ids := &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{Data: []int64{1, 2}},
},
}
result, _, err := retrieveByPKs(context.Background(), task, ids, []string{"*"})
// Verify results
assert.NoError(t, err)
assert.NotNil(t, result)
assert.Equal(t, commonpb.ErrorCode_Success, result.Status.ErrorCode)
assert.Len(t, result.FieldsData, 1)
})
}
// TestRetrieveByPKs_GetPrimaryFieldSchemaError checks that retrieveByPKs
// returns an error and a nil result when the primary-key schema lookup fails.
func TestRetrieveByPKs_GetPrimaryFieldSchemaError(t *testing.T) {
	mockey.PatchConvey("TestRetrieveByPKs_GetPrimaryFieldSchemaError", t, func() {
		schemaErr := merr.WrapErrParameterInvalidMsg("primary field not found")
		mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(nil, schemaErr).Build()

		task := createTestUpdateTask()
		pks := &schemapb.IDs{
			IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1, 2}}},
		}

		res, _, retrieveErr := retrieveByPKs(context.Background(), task, pks, []string{"*"})
		assert.Error(t, retrieveErr)
		assert.Nil(t, res)
		assert.Contains(t, retrieveErr.Error(), "primary field not found")
	})
}
// TestRetrieveByPKs_PartitionKeyMode runs retrieveByPKs with partitionKeyMode
// enabled. The schema lookup, requery plan and Proxy.query are mocked, and
// the test expects a non-nil result without error.
func TestRetrieveByPKs_PartitionKeyMode(t *testing.T) {
	mockey.PatchConvey("TestRetrieveByPKs_PartitionKeyMode", t, func() {
		pkSchema := &schemapb.FieldSchema{
			FieldID:      100,
			Name:         "id",
			IsPrimaryKey: true,
			DataType:     schemapb.DataType_Int64,
		}
		emptyResult := &milvuspb.QueryResults{
			Status:     merr.Success(),
			FieldsData: []*schemapb.FieldData{},
		}
		mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(pkSchema, nil).Build()
		mockey.Mock(planparserv2.CreateRequeryPlan).Return(&planpb.PlanNode{}).Build()
		mockey.Mock((*Proxy).query).Return(emptyResult, segcore.StorageCost{}, nil).Build()

		task := createTestUpdateTask()
		task.partitionKeyMode = true
		pks := &schemapb.IDs{
			IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1, 2}}},
		}

		res, _, retrieveErr := retrieveByPKs(context.Background(), task, pks, []string{"*"})
		assert.NoError(t, retrieveErr)
		assert.NotNil(t, res)
	})
}
// TestUpdateTask_queryPreExecute_Success drives queryPreExecute with three
// incoming primary keys (1, 2, 3), of which the mocked retrieveByPKs reports
// two (1, 2) as existing. It expects deletePKs and insertFieldData to be
// populated.
func TestUpdateTask_queryPreExecute_Success(t *testing.T) {
	mockey.PatchConvey("TestUpdateTask_queryPreExecute_Success", t, func() {
		// idColumn builds an INT64 "id" (FieldId 100) column holding values.
		idColumn := func(values ...int64) *schemapb.FieldData {
			return &schemapb.FieldData{
				FieldName: "id",
				FieldId:   100,
				Type:      schemapb.DataType_Int64,
				Field: &schemapb.FieldData_Scalars{
					Scalars: &schemapb.ScalarField{
						Data: &schemapb.ScalarField_LongData{
							LongData: &schemapb.LongArray{Data: values},
						},
					},
				},
			}
		}

		pkSchema := &schemapb.FieldSchema{
			FieldID:      100,
			Name:         "id",
			IsPrimaryKey: true,
			DataType:     schemapb.DataType_Int64,
		}
		incomingIDs := &schemapb.IDs{
			IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{1, 2, 3}}},
		}
		oldNames := &schemapb.FieldData{
			FieldName: "name",
			FieldId:   102,
			Type:      schemapb.DataType_VarChar,
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_StringData{
						StringData: &schemapb.StringArray{Data: []string{"old1", "old2"}},
					},
				},
			},
		}
		oldVectors := &schemapb.FieldData{
			FieldName: "vector",
			FieldId:   101,
			Type:      schemapb.DataType_FloatVector,
			Field: &schemapb.FieldData_Vectors{
				Vectors: &schemapb.VectorField{
					Dim: 128,
					Data: &schemapb.VectorField_FloatVector{
						FloatVector: &schemapb.FloatArray{Data: make([]float32, 256)}, // 2 * 128
					},
				},
			},
		}
		existingRows := &milvuspb.QueryResults{
			Status:     merr.Success(),
			FieldsData: []*schemapb.FieldData{idColumn(1, 2), oldNames, oldVectors},
		}

		mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(pkSchema, nil).Build()
		mockey.Mock(typeutil.GetPrimaryFieldData).Return(idColumn(1, 2, 3), nil).Build()
		mockey.Mock(parsePrimaryFieldData2IDs).Return(incomingIDs, nil).Build()
		mockey.Mock(typeutil.GetSizeOfIDs).Return(3).Build()
		mockey.Mock(retrieveByPKs).Return(existingRows, segcore.StorageCost{}, nil).Build()
		mockey.Mock(typeutil.NewIDsChecker).Return(&typeutil.IDsChecker{}, nil).Build()

		task := createTestUpdateTask()
		task.schema = createTestSchema()
		task.upsertMsg = &msgstream.UpsertMsg{
			InsertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{
					FieldsData: []*schemapb.FieldData{
						{FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64},
						{FieldName: "name", FieldId: 102, Type: schemapb.DataType_VarChar},
						{FieldName: "vector", FieldId: 101, Type: schemapb.DataType_FloatVector},
					},
				},
			},
		}

		preErr := task.queryPreExecute(context.Background())
		assert.NoError(t, preErr)
		assert.NotNil(t, task.deletePKs)
		assert.NotNil(t, task.insertFieldData)
	})
}
// TestUpdateTask_queryPreExecute_GetPrimaryFieldSchemaError checks that a
// failing primary-key schema lookup surfaces as an error from queryPreExecute.
func TestUpdateTask_queryPreExecute_GetPrimaryFieldSchemaError(t *testing.T) {
	mockey.PatchConvey("TestUpdateTask_queryPreExecute_GetPrimaryFieldSchemaError", t, func() {
		schemaErr := merr.WrapErrParameterInvalidMsg("primary field not found")
		mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(nil, schemaErr).Build()

		task := createTestUpdateTask()
		task.schema = createTestSchema()

		preErr := task.queryPreExecute(context.Background())
		assert.Error(t, preErr)
		assert.Contains(t, preErr.Error(), "primary field not found")
	})
}
// TestUpdateTask_queryPreExecute_GetPrimaryFieldDataError expects
// queryPreExecute to fail with a "must assign pk when upsert" error when the
// primary-key column cannot be extracted from the request.
func TestUpdateTask_queryPreExecute_GetPrimaryFieldDataError(t *testing.T) {
	mockey.PatchConvey("TestUpdateTask_queryPreExecute_GetPrimaryFieldDataError", t, func() {
		pkSchema := &schemapb.FieldSchema{
			FieldID:      100,
			Name:         "id",
			IsPrimaryKey: true,
			DataType:     schemapb.DataType_Int64,
		}
		dataErr := merr.WrapErrParameterInvalidMsg("primary field data not found")
		mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(pkSchema, nil).Build()
		mockey.Mock(typeutil.GetPrimaryFieldData).Return(nil, dataErr).Build()

		task := createTestUpdateTask()
		task.schema = createTestSchema()

		preErr := task.queryPreExecute(context.Background())
		assert.Error(t, preErr)
		assert.Contains(t, preErr.Error(), "must assign pk when upsert")
	})
}
// TestUpdateTask_queryPreExecute_EmptyOldIDs covers the case where the parsed
// primary keys are empty (GetSizeOfIDs mocked to 0). queryPreExecute succeeds,
// deletePKs is set, and insertFieldData equals the request's FieldsData.
func TestUpdateTask_queryPreExecute_EmptyOldIDs(t *testing.T) {
	mockey.PatchConvey("TestUpdateTask_queryPreExecute_EmptyOldIDs", t, func() {
		pkSchema := &schemapb.FieldSchema{
			FieldID:      100,
			Name:         "id",
			IsPrimaryKey: true,
			DataType:     schemapb.DataType_Int64,
		}
		pkColumn := &schemapb.FieldData{
			FieldName: "id",
			FieldId:   100,
			Type:      schemapb.DataType_Int64,
		}
		mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(pkSchema, nil).Build()
		mockey.Mock(typeutil.GetPrimaryFieldData).Return(pkColumn, nil).Build()
		mockey.Mock(parsePrimaryFieldData2IDs).Return(&schemapb.IDs{}, nil).Build()
		mockey.Mock(typeutil.GetSizeOfIDs).Return(0).Build()

		task := createTestUpdateTask()
		task.schema = createTestSchema()

		preErr := task.queryPreExecute(context.Background())
		assert.NoError(t, preErr)
		assert.NotNil(t, task.deletePKs)
		assert.Equal(t, task.req.GetFieldsData(), task.insertFieldData)
	})
}
func TestUpdateTask_PreExecute_Success(t *testing.T) {
mockey.PatchConvey("TestUpdateTask_PreExecute_Success", t, func() {
// Setup mocks
globalMetaCache = &MetaCache{}
mockey.Mock(GetReplicateID).Return("", nil).Build()
mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build()
mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{
updateTimestamp: 12345,
}, nil).Build()
mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build()
mockey.Mock(isPartitionKeyMode).Return(false, nil).Build()
mockey.Mock((*MetaCache).GetPartitionInfo).Return(&partitionInfo{
name: "_default",
}, nil).Build()
mockey.Mock((*upsertTask).queryPreExecute).Return(nil).Build()
mockey.Mock((*upsertTask).insertPreExecute).Return(nil).Build()
mockey.Mock((*upsertTask).deletePreExecute).Return(nil).Build()
// Execute test
task := createTestUpdateTask()
task.req.PartialUpdate = true
err := task.PreExecute(context.Background())
// Verify results
assert.NoError(t, err)
assert.NotNil(t, task.result)
assert.Equal(t, int64(1001), task.collectionID)
assert.NotNil(t, task.schema)
assert.NotNil(t, task.upsertMsg)
})
}
func TestUpdateTask_PreExecute_ReplicateIDError(t *testing.T) {
mockey.PatchConvey("TestUpdateTask_PreExecute_ReplicateIDError", t, func() {
globalMetaCache = &MetaCache{}
mockey.Mock(GetReplicateID).Return("replica1", nil).Build()
task := createTestUpdateTask()
err := task.PreExecute(context.Background())
assert.Error(t, err)
assert.Contains(t, err.Error(), "can't operate on the collection under standby mode")
})
}
func TestUpdateTask_PreExecute_GetCollectionIDError(t *testing.T) {
mockey.PatchConvey("TestUpdateTask_PreExecute_GetCollectionIDError", t, func() {
globalMetaCache = &MetaCache{}
mockey.Mock(GetReplicateID).Return("", nil).Build()
expectedErr := merr.WrapErrCollectionNotFound("test_collection")
mockey.Mock((*MetaCache).GetCollectionID).Return(int64(0), expectedErr).Build()
task := createTestUpdateTask()
err := task.PreExecute(context.Background())
assert.Error(t, err)
})
}
// TestUpdateTask_PreExecute_PartitionKeyModeError checks that PreExecute
// rejects an explicit partition name when the collection is patched to report
// partition-key mode.
func TestUpdateTask_PreExecute_PartitionKeyModeError(t *testing.T) {
	mockey.PatchConvey("TestUpdateTask_PreExecute_PartitionKeyModeError", t, func() {
		// Swap in a bare MetaCache and restore the previous global afterwards
		// so this test does not leak state into other tests.
		prevMetaCache := globalMetaCache
		defer func() { globalMetaCache = prevMetaCache }()
		globalMetaCache = &MetaCache{}

		mockey.Mock(GetReplicateID).Return("", nil).Build()
		mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build()
		mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{
			updateTimestamp: 12345,
		}, nil).Build()
		mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build()
		mockey.Mock(isPartitionKeyMode).Return(true, nil).Build()

		task := createTestUpdateTask()
		task.req.PartitionName = "custom_partition" // This should cause error in partition key mode
		err := task.PreExecute(context.Background())

		assert.Error(t, err)
		assert.Contains(t, err.Error(), "not support manually specifying the partition names if partition key mode is used")
	})
}
// TestUpdateTask_PreExecute_InvalidNumRows checks that PreExecute reports
// "invalid num_rows" when the request carries NumRows == 0 and all metadata
// lookups are patched to succeed.
func TestUpdateTask_PreExecute_InvalidNumRows(t *testing.T) {
	mockey.PatchConvey("TestUpdateTask_PreExecute_InvalidNumRows", t, func() {
		// Swap in a bare MetaCache and restore the previous global afterwards
		// so this test does not leak state into other tests.
		prevMetaCache := globalMetaCache
		defer func() { globalMetaCache = prevMetaCache }()
		globalMetaCache = &MetaCache{}

		mockey.Mock(GetReplicateID).Return("", nil).Build()
		mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build()
		mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{
			updateTimestamp: 12345,
		}, nil).Build()
		mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build()
		mockey.Mock(isPartitionKeyMode).Return(false, nil).Build()
		mockey.Mock((*MetaCache).GetPartitionInfo).Return(&partitionInfo{
			name: "_default",
		}, nil).Build()

		task := createTestUpdateTask()
		task.req.NumRows = 0 // Invalid num_rows
		err := task.PreExecute(context.Background())

		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid num_rows")
	})
}
// TestUpdateTask_PreExecute_QueryPreExecuteError checks that, on a partial
// update, an error from the patched queryPreExecute stage is propagated out of
// PreExecute with its message intact.
func TestUpdateTask_PreExecute_QueryPreExecuteError(t *testing.T) {
	mockey.PatchConvey("TestUpdateTask_PreExecute_QueryPreExecuteError", t, func() {
		// Swap in a bare MetaCache and restore the previous global afterwards
		// so this test does not leak state into other tests.
		prevMetaCache := globalMetaCache
		defer func() { globalMetaCache = prevMetaCache }()
		globalMetaCache = &MetaCache{}

		mockey.Mock(GetReplicateID).Return("", nil).Build()
		mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build()
		mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{
			updateTimestamp: 12345,
		}, nil).Build()
		mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build()
		mockey.Mock(isPartitionKeyMode).Return(false, nil).Build()
		mockey.Mock((*MetaCache).GetPartitionInfo).Return(&partitionInfo{
			name: "_default",
		}, nil).Build()
		expectedErr := merr.WrapErrParameterInvalidMsg("query pre-execute failed")
		mockey.Mock((*upsertTask).queryPreExecute).Return(expectedErr).Build()

		task := createTestUpdateTask()
		task.req.PartialUpdate = true
		err := task.PreExecute(context.Background())

		assert.Error(t, err)
		assert.Contains(t, err.Error(), "query pre-execute failed")
	})
}
// TestUpsertTask_queryPreExecute_MixLogic covers an upsert batch in which some
// primary keys already exist (1, 2) and one is new (3). retrieveByPKs is
// patched to report the existing rows. The test checks that only the existing
// keys are scheduled for deletion and that the merged insert payload keeps
// every incoming key together with the caller-supplied values.
func TestUpsertTask_queryPreExecute_MixLogic(t *testing.T) {
	collSchema := newSchemaInfo(&schemapb.CollectionSchema{
		Name: "test_merge_collection",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "id", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
			{FieldID: 101, Name: "value", DataType: schemapb.DataType_Int32},
			{FieldID: 102, Name: "extra", DataType: schemapb.DataType_VarChar, Nullable: true},
		},
	})

	// Incoming rows: PKs 1 and 2 are updates, PK 3 is a fresh insert.
	incomingPKs := &schemapb.LongArray{Data: []int64{1, 2, 3}}
	incoming := []*schemapb.FieldData{
		{
			FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: incomingPKs}}},
		},
		{
			FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{100, 200, 300}}}}},
		},
	}
	rowCount := uint64(len(incomingPKs.GetData()))

	// Rows the patched lookup reports as already stored: PKs 1 and 2.
	existing := &milvuspb.QueryResults{
		Status: merr.Success(),
		FieldsData: []*schemapb.FieldData{
			{
				FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
				Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2}}}}},
			},
			{
				FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
				Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{10, 20}}}}},
			},
			{
				FieldName: "extra", FieldId: 102, Type: schemapb.DataType_VarChar,
				Field:     &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: []string{"old1", "old2"}}}}},
				ValidData: []bool{true, true},
			},
		},
	}

	// The request and the upsert message deliberately share one FieldsData slice.
	task := &upsertTask{
		ctx:    context.Background(),
		schema: collSchema,
		req: &milvuspb.UpsertRequest{
			FieldsData: incoming,
			NumRows:    uint32(rowCount),
		},
		upsertMsg: &msgstream.UpsertMsg{
			InsertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{
					FieldsData: incoming,
					NumRows:    rowCount,
				},
			},
		},
		node: &Proxy{},
	}

	patch := mockey.Mock(retrieveByPKs).Return(existing, segcore.StorageCost{}, nil).Build()
	defer patch.UnPatch()

	err := task.queryPreExecute(context.Background())
	assert.NoError(t, err)

	// Only the keys found by the lookup are scheduled for deletion.
	assert.ElementsMatch(t, []int64{1, 2}, task.deletePKs.GetIntId().GetData())

	// The merged insert payload carries every incoming key, in order.
	pkSchema, err := typeutil.GetPrimaryFieldSchema(collSchema.CollectionSchema)
	assert.NoError(t, err)
	pkData, err := typeutil.GetPrimaryFieldData(task.insertFieldData, pkSchema)
	assert.NoError(t, err)
	mergedIDs, err := parsePrimaryFieldData2IDs(pkData)
	assert.NoError(t, err)
	assert.Equal(t, []int64{1, 2, 3}, mergedIDs.GetIntId().GetData())

	// Incoming values replace the stored ones (10, 20).
	var valueField *schemapb.FieldData
	for i := range task.insertFieldData {
		if task.insertFieldData[i].GetFieldName() == "value" {
			valueField = task.insertFieldData[i]
			break
		}
	}
	assert.NotNil(t, valueField)
	assert.Equal(t, []int32{100, 200, 300}, valueField.GetScalars().GetIntData().GetData())
}
// TestUpsertTask_queryPreExecute_PureInsert covers an upsert batch (PKs 4, 5)
// where none of the keys exist. retrieveByPKs is patched to return the field
// columns with zero rows. The test checks that nothing is scheduled for
// deletion and that the insert payload is exactly the incoming data.
func TestUpsertTask_queryPreExecute_PureInsert(t *testing.T) {
	collSchema := newSchemaInfo(&schemapb.CollectionSchema{
		Name: "test_merge_collection",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "id", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
			{FieldID: 101, Name: "value", DataType: schemapb.DataType_Int32},
			{FieldID: 102, Name: "extra", DataType: schemapb.DataType_VarChar, Nullable: true},
		},
	})

	// Incoming rows: PKs 4 and 5, both new.
	incomingPKs := &schemapb.LongArray{Data: []int64{4, 5}}
	incoming := []*schemapb.FieldData{
		{
			FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: incomingPKs}}},
		},
		{
			FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{400, 500}}}}},
		},
	}
	rowCount := uint64(len(incomingPKs.GetData()))

	// The lookup finds nothing, but every column is still present (empty).
	existing := &milvuspb.QueryResults{
		Status: merr.Success(),
		FieldsData: []*schemapb.FieldData{
			{
				FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
				Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{}}}}},
			},
			{
				FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
				Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{}}}}},
			},
			{
				FieldName: "extra", FieldId: 102, Type: schemapb.DataType_VarChar,
				Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: []string{}}}}},
			},
		},
	}

	// The request and the upsert message deliberately share one FieldsData slice.
	task := &upsertTask{
		ctx:    context.Background(),
		schema: collSchema,
		req: &milvuspb.UpsertRequest{
			FieldsData: incoming,
			NumRows:    uint32(rowCount),
		},
		upsertMsg: &msgstream.UpsertMsg{
			InsertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{
					FieldsData: incoming,
					NumRows:    rowCount,
				},
			},
		},
		node: &Proxy{},
	}

	patch := mockey.Mock(retrieveByPKs).Return(existing, segcore.StorageCost{}, nil).Build()
	defer patch.UnPatch()

	err := task.queryPreExecute(context.Background())
	assert.NoError(t, err)

	// No stored rows means no deletions.
	assert.Empty(t, task.deletePKs.GetIntId().GetData())

	// The insert payload carries exactly the incoming keys.
	pkSchema, err := typeutil.GetPrimaryFieldSchema(collSchema.CollectionSchema)
	assert.NoError(t, err)
	pkData, err := typeutil.GetPrimaryFieldData(task.insertFieldData, pkSchema)
	assert.NoError(t, err)
	mergedIDs, err := parsePrimaryFieldData2IDs(pkData)
	assert.NoError(t, err)
	assert.Equal(t, []int64{4, 5}, mergedIDs.GetIntId().GetData())

	var valueField *schemapb.FieldData
	for i := range task.insertFieldData {
		if task.insertFieldData[i].GetFieldName() == "value" {
			valueField = task.insertFieldData[i]
			break
		}
	}
	assert.NotNil(t, valueField)
	assert.Equal(t, []int32{400, 500}, valueField.GetScalars().GetIntData().GetData())
}
// TestUpsertTask_queryPreExecute_PureUpdate covers an upsert batch (PKs 6, 7)
// where every key already exists. retrieveByPKs is patched to report both rows
// with old values (60, 70). The test checks that both keys are scheduled for
// deletion and that the merged insert payload keeps the incoming values.
func TestUpsertTask_queryPreExecute_PureUpdate(t *testing.T) {
	collSchema := newSchemaInfo(&schemapb.CollectionSchema{
		Name: "test_merge_collection",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "id", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
			{FieldID: 101, Name: "value", DataType: schemapb.DataType_Int32},
			{FieldID: 102, Name: "extra", DataType: schemapb.DataType_VarChar, Nullable: true},
		},
	})

	// Incoming rows: PKs 6 and 7, both already stored.
	incomingPKs := &schemapb.LongArray{Data: []int64{6, 7}}
	incoming := []*schemapb.FieldData{
		{
			FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: incomingPKs}}},
		},
		{
			FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{600, 700}}}}},
		},
	}
	rowCount := uint64(len(incomingPKs.GetData()))

	// Rows the patched lookup reports as already stored: PKs 6 and 7.
	existing := &milvuspb.QueryResults{
		Status: merr.Success(),
		FieldsData: []*schemapb.FieldData{
			{
				FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
				Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{6, 7}}}}},
			},
			{
				FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
				Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{60, 70}}}}},
			},
		},
	}

	// The request and the upsert message deliberately share one FieldsData slice.
	task := &upsertTask{
		ctx:    context.Background(),
		schema: collSchema,
		req: &milvuspb.UpsertRequest{
			FieldsData: incoming,
			NumRows:    uint32(rowCount),
		},
		upsertMsg: &msgstream.UpsertMsg{
			InsertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{
					FieldsData: incoming,
					NumRows:    rowCount,
				},
			},
		},
		node: &Proxy{},
	}

	patch := mockey.Mock(retrieveByPKs).Return(existing, segcore.StorageCost{}, nil).Build()
	defer patch.UnPatch()

	err := task.queryPreExecute(context.Background())
	assert.NoError(t, err)

	// Every incoming key was found, so every key is scheduled for deletion.
	assert.ElementsMatch(t, []int64{6, 7}, task.deletePKs.GetIntId().GetData())

	// The merged insert payload carries the incoming keys, in order.
	pkSchema, err := typeutil.GetPrimaryFieldSchema(collSchema.CollectionSchema)
	assert.NoError(t, err)
	pkData, err := typeutil.GetPrimaryFieldData(task.insertFieldData, pkSchema)
	assert.NoError(t, err)
	mergedIDs, err := parsePrimaryFieldData2IDs(pkData)
	assert.NoError(t, err)
	assert.Equal(t, []int64{6, 7}, mergedIDs.GetIntId().GetData())

	// Incoming values replace the stored ones (60, 70).
	var valueField *schemapb.FieldData
	for i := range task.insertFieldData {
		if task.insertFieldData[i].GetFieldName() == "value" {
			valueField = task.insertFieldData[i]
			break
		}
	}
	assert.NotNil(t, valueField)
	assert.Equal(t, []int32{600, 700}, valueField.GetScalars().GetIntData().GetData())
}