mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
issue: #43980 Fixes a panic that occurred when a partial update was converted to an insert due to a non-existent primary key. The panic was caused by missing nullable fields that were not provided in the original partial update request. The upsert pre-execution logic is refactored to handle this correctly: - Explicitly splits upsert data into 'insert' and 'update' batches. - Automatically generates data for missing nullable or default-value fields during inserts, preventing the panic. - Enhances `typeutil.UpdateFieldData` to support different source and destination indexes for flexible data merging. - Adds comprehensive unit tests for mixed upsert, pure insert, and pure update scenarios. --------- Signed-off-by: Wei Liu <wei.liu@zilliz.com>
1355 lines
44 KiB
Go
1355 lines
44 KiB
Go
// Licensed to the LF AI & Data foundation under one
|
|
// or more contributor license agreements. See the NOTICE file
|
|
// distributed with this work for additional information
|
|
// regarding copyright ownership. The ASF licenses this file
|
|
// to you under the Apache License, Version 2.0 (the
|
|
// "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/bytedance/mockey"
|
|
"github.com/cockroachdb/errors"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/internal/allocator"
|
|
grpcmixcoordclient "github.com/milvus-io/milvus/internal/distributed/mixcoord/client"
|
|
"github.com/milvus-io/milvus/internal/mocks"
|
|
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
|
"github.com/milvus-io/milvus/internal/util/function/embedding"
|
|
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
|
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
|
|
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/testutils"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
|
)
|
|
|
|
func TestUpsertTask_CheckAligned(t *testing.T) {
|
|
var err error
|
|
|
|
// passed NumRows is less than 0
|
|
case1 := upsertTask{
|
|
req: &milvuspb.UpsertRequest{
|
|
NumRows: 0,
|
|
},
|
|
upsertMsg: &msgstream.UpsertMsg{
|
|
InsertMsg: &msgstream.InsertMsg{
|
|
InsertRequest: &msgpb.InsertRequest{},
|
|
},
|
|
},
|
|
}
|
|
case1.upsertMsg.InsertMsg.InsertRequest = &msgpb.InsertRequest{
|
|
Base: commonpbutil.NewMsgBase(
|
|
commonpbutil.WithMsgType(commonpb.MsgType_Insert),
|
|
),
|
|
CollectionName: case1.req.CollectionName,
|
|
PartitionName: case1.req.PartitionName,
|
|
FieldsData: case1.req.FieldsData,
|
|
NumRows: uint64(case1.req.NumRows),
|
|
Version: msgpb.InsertDataVersion_ColumnBased,
|
|
}
|
|
|
|
err = case1.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.NoError(t, err)
|
|
|
|
// checkFieldsDataBySchema was already checked by TestUpsertTask_checkFieldsDataBySchema
|
|
|
|
boolFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}
|
|
int8FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int8}
|
|
int16FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int16}
|
|
int32FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int32}
|
|
int64FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}
|
|
floatFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Float}
|
|
doubleFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Double}
|
|
floatVectorFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector}
|
|
binaryVectorFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_BinaryVector}
|
|
varCharFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}
|
|
|
|
numRows := 20
|
|
dim := 128
|
|
collSchema := &schemapb.CollectionSchema{
|
|
Name: "TestUpsertTask_checkRowNums",
|
|
Description: "TestUpsertTask_checkRowNums",
|
|
AutoID: false,
|
|
Fields: []*schemapb.FieldSchema{
|
|
boolFieldSchema,
|
|
int8FieldSchema,
|
|
int16FieldSchema,
|
|
int32FieldSchema,
|
|
int64FieldSchema,
|
|
floatFieldSchema,
|
|
doubleFieldSchema,
|
|
floatVectorFieldSchema,
|
|
binaryVectorFieldSchema,
|
|
varCharFieldSchema,
|
|
},
|
|
}
|
|
schema := newSchemaInfo(collSchema)
|
|
case2 := upsertTask{
|
|
req: &milvuspb.UpsertRequest{
|
|
NumRows: uint32(numRows),
|
|
FieldsData: []*schemapb.FieldData{},
|
|
},
|
|
rowIDs: testutils.GenerateInt64Array(numRows),
|
|
timestamps: testutils.GenerateUint64Array(numRows),
|
|
schema: schema,
|
|
upsertMsg: &msgstream.UpsertMsg{
|
|
InsertMsg: &msgstream.InsertMsg{
|
|
InsertRequest: &msgpb.InsertRequest{},
|
|
},
|
|
},
|
|
}
|
|
|
|
// satisfied
|
|
case2.req.FieldsData = []*schemapb.FieldData{
|
|
newScalarFieldData(boolFieldSchema, "Bool", numRows),
|
|
newScalarFieldData(int8FieldSchema, "Int8", numRows),
|
|
newScalarFieldData(int16FieldSchema, "Int16", numRows),
|
|
newScalarFieldData(int32FieldSchema, "Int32", numRows),
|
|
newScalarFieldData(int64FieldSchema, "Int64", numRows),
|
|
newScalarFieldData(floatFieldSchema, "Float", numRows),
|
|
newScalarFieldData(doubleFieldSchema, "Double", numRows),
|
|
newFloatVectorFieldData("FloatVector", numRows, dim),
|
|
newBinaryVectorFieldData("BinaryVector", numRows, dim),
|
|
newScalarFieldData(varCharFieldSchema, "VarChar", numRows),
|
|
}
|
|
case2.upsertMsg.InsertMsg.InsertRequest = &msgpb.InsertRequest{
|
|
Base: commonpbutil.NewMsgBase(
|
|
commonpbutil.WithMsgType(commonpb.MsgType_Insert),
|
|
),
|
|
CollectionName: case2.req.CollectionName,
|
|
PartitionName: case2.req.PartitionName,
|
|
FieldsData: case2.req.FieldsData,
|
|
NumRows: uint64(case2.req.NumRows),
|
|
RowIDs: case2.rowIDs,
|
|
Timestamps: case2.timestamps,
|
|
Version: msgpb.InsertDataVersion_ColumnBased,
|
|
}
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.NoError(t, err)
|
|
|
|
// less bool data
|
|
case2.req.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows/2)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// more bool data
|
|
case2.req.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows*2)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// revert
|
|
case2.req.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.NoError(t, nil, err)
|
|
|
|
// less int8 data
|
|
case2.req.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows/2)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// more int8 data
|
|
case2.req.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows*2)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// revert
|
|
case2.req.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.NoError(t, err)
|
|
|
|
// less int16 data
|
|
case2.req.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows/2)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// more int16 data
|
|
case2.req.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows*2)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// revert
|
|
case2.req.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.NoError(t, err)
|
|
|
|
// less int32 data
|
|
case2.req.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows/2)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// more int32 data
|
|
case2.req.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows*2)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// revert
|
|
case2.req.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.NoError(t, err)
|
|
|
|
// less int64 data
|
|
case2.req.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows/2)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// more int64 data
|
|
case2.req.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows*2)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// revert
|
|
case2.req.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.NoError(t, err)
|
|
|
|
// less float data
|
|
case2.req.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows/2)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// more float data
|
|
case2.req.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows*2)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// revert
|
|
case2.req.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.NoError(t, nil, err)
|
|
|
|
// less double data
|
|
case2.req.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows/2)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// more double data
|
|
case2.req.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows*2)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// revert
|
|
case2.req.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.NoError(t, nil, err)
|
|
|
|
// less float vectors
|
|
case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows/2, dim)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// more float vectors
|
|
case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows*2, dim)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// revert
|
|
case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows, dim)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.NoError(t, err)
|
|
|
|
// less binary vectors
|
|
case2.req.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows/2, dim)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// more binary vectors
|
|
case2.req.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows*2, dim)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// revert
|
|
case2.req.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows, dim)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.NoError(t, err)
|
|
|
|
// less double data
|
|
case2.req.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows/2)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// more double data
|
|
case2.req.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows*2)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.Error(t, err)
|
|
// revert
|
|
case2.req.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows)
|
|
case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
|
|
err = case2.upsertMsg.InsertMsg.CheckAligned()
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func TestUpsertTask(t *testing.T) {
|
|
t.Run("test getChannels", func(t *testing.T) {
|
|
collectionID := UniqueID(0)
|
|
collectionName := "col-0"
|
|
channels := []pChan{"mock-chan-0", "mock-chan-1"}
|
|
cache := NewMockCache(t)
|
|
cache.On("GetCollectionID",
|
|
mock.Anything, // context.Context
|
|
mock.AnythingOfType("string"),
|
|
mock.AnythingOfType("string"),
|
|
).Return(collectionID, nil)
|
|
globalMetaCache = cache
|
|
|
|
chMgr := NewMockChannelsMgr(t)
|
|
chMgr.EXPECT().getChannels(mock.Anything).Return(channels, nil)
|
|
ut := upsertTask{
|
|
ctx: context.Background(),
|
|
req: &milvuspb.UpsertRequest{
|
|
CollectionName: collectionName,
|
|
},
|
|
chMgr: chMgr,
|
|
}
|
|
err := ut.setChannels()
|
|
assert.NoError(t, err)
|
|
resChannels := ut.getChannels()
|
|
assert.ElementsMatch(t, channels, resChannels)
|
|
assert.ElementsMatch(t, channels, ut.pChannels)
|
|
})
|
|
}
|
|
|
|
func TestUpsertTaskForReplicate(t *testing.T) {
|
|
cache := globalMetaCache
|
|
defer func() { globalMetaCache = cache }()
|
|
mockCache := NewMockCache(t)
|
|
globalMetaCache = mockCache
|
|
ctx := context.Background()
|
|
|
|
t.Run("fail to get collection info", func(t *testing.T) {
|
|
ut := upsertTask{
|
|
ctx: ctx,
|
|
req: &milvuspb.UpsertRequest{
|
|
CollectionName: "col-0",
|
|
},
|
|
}
|
|
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("foo")).Once()
|
|
err := ut.PreExecute(ctx)
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("replicate mode", func(t *testing.T) {
|
|
ut := upsertTask{
|
|
ctx: ctx,
|
|
req: &milvuspb.UpsertRequest{
|
|
CollectionName: "col-0",
|
|
},
|
|
}
|
|
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
|
|
replicateID: "local-mac",
|
|
}, nil).Once()
|
|
err := ut.PreExecute(ctx)
|
|
assert.Error(t, err)
|
|
})
|
|
}
|
|
|
|
func TestUpsertTask_Function(t *testing.T) {
|
|
paramtable.Init()
|
|
paramtable.Get().CredentialCfg.Credential.GetFunc = func() map[string]string {
|
|
return map[string]string{
|
|
"mock.apikey": "mock",
|
|
}
|
|
}
|
|
ts := embedding.CreateOpenAIEmbeddingServer()
|
|
defer ts.Close()
|
|
paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
|
|
return map[string]string{
|
|
"openai.url": ts.URL,
|
|
}
|
|
}
|
|
|
|
data := []*schemapb.FieldData{}
|
|
f1 := schemapb.FieldData{
|
|
Type: schemapb.DataType_Int64,
|
|
FieldId: 100,
|
|
FieldName: "id",
|
|
IsDynamic: false,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_LongData{
|
|
LongData: &schemapb.LongArray{
|
|
Data: []int64{0, 1},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
data = append(data, &f1)
|
|
f2 := schemapb.FieldData{
|
|
Type: schemapb.DataType_VarChar,
|
|
FieldId: 101,
|
|
FieldName: "text",
|
|
IsDynamic: false,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_StringData{
|
|
StringData: &schemapb.StringArray{
|
|
Data: []string{"sentence", "sentence"},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
data = append(data, &f2)
|
|
collectionName := "TestUpsertTask_function"
|
|
schema := &schemapb.CollectionSchema{
|
|
Name: collectionName,
|
|
Description: "TestUpsertTask_function",
|
|
AutoID: true,
|
|
Fields: []*schemapb.FieldSchema{
|
|
{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
|
|
{
|
|
FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar,
|
|
TypeParams: []*commonpb.KeyValuePair{
|
|
{Key: "max_length", Value: "200"},
|
|
},
|
|
},
|
|
{
|
|
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
|
TypeParams: []*commonpb.KeyValuePair{
|
|
{Key: "dim", Value: "4"},
|
|
},
|
|
IsFunctionOutput: true,
|
|
},
|
|
},
|
|
Functions: []*schemapb.FunctionSchema{
|
|
{
|
|
Name: "test_function",
|
|
Type: schemapb.FunctionType_TextEmbedding,
|
|
InputFieldIds: []int64{101},
|
|
InputFieldNames: []string{"text"},
|
|
OutputFieldIds: []int64{102},
|
|
OutputFieldNames: []string{"vector"},
|
|
Params: []*commonpb.KeyValuePair{
|
|
{Key: "provider", Value: "openai"},
|
|
{Key: "model_name", Value: "text-embedding-ada-002"},
|
|
{Key: "credential", Value: "mock"},
|
|
{Key: "dim", Value: "4"},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
info := newSchemaInfo(schema)
|
|
collectionID := UniqueID(0)
|
|
cache := NewMockCache(t)
|
|
globalMetaCache = cache
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
rc := mocks.NewMockRootCoordClient(t)
|
|
rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{
|
|
Status: merr.Status(nil),
|
|
ID: collectionID,
|
|
Count: 10,
|
|
}, nil)
|
|
idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0)
|
|
idAllocator.Start()
|
|
defer idAllocator.Close()
|
|
assert.NoError(t, err)
|
|
task := upsertTask{
|
|
ctx: context.Background(),
|
|
req: &milvuspb.UpsertRequest{
|
|
CollectionName: collectionName,
|
|
},
|
|
upsertMsg: &msgstream.UpsertMsg{
|
|
InsertMsg: &msgstream.InsertMsg{
|
|
InsertRequest: &msgpb.InsertRequest{
|
|
Base: commonpbutil.NewMsgBase(
|
|
commonpbutil.WithMsgType(commonpb.MsgType_Insert),
|
|
),
|
|
CollectionName: collectionName,
|
|
DbName: "hooooooo",
|
|
Version: msgpb.InsertDataVersion_ColumnBased,
|
|
FieldsData: data,
|
|
NumRows: 2,
|
|
PartitionName: Params.CommonCfg.DefaultPartitionName.GetValue(),
|
|
},
|
|
},
|
|
},
|
|
idAllocator: idAllocator,
|
|
schema: info,
|
|
result: &milvuspb.MutationResult{},
|
|
}
|
|
err = task.insertPreExecute(ctx)
|
|
assert.NoError(t, err)
|
|
|
|
// process failed
|
|
{
|
|
oldRows := task.upsertMsg.InsertMsg.InsertRequest.NumRows
|
|
task.upsertMsg.InsertMsg.InsertRequest.NumRows = 10000
|
|
err = task.insertPreExecute(ctx)
|
|
assert.Error(t, err)
|
|
task.upsertMsg.InsertMsg.InsertRequest.NumRows = oldRows
|
|
}
|
|
}
|
|
|
|
func TestUpsertTaskForSchemaMismatch(t *testing.T) {
|
|
cache := globalMetaCache
|
|
defer func() { globalMetaCache = cache }()
|
|
mockCache := NewMockCache(t)
|
|
globalMetaCache = mockCache
|
|
ctx := context.Background()
|
|
|
|
t.Run("schema ts mismatch", func(t *testing.T) {
|
|
ut := upsertTask{
|
|
ctx: ctx,
|
|
req: &milvuspb.UpsertRequest{
|
|
CollectionName: "col-0",
|
|
NumRows: 10,
|
|
},
|
|
schemaTimestamp: 99,
|
|
}
|
|
mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil)
|
|
mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
|
|
updateTimestamp: 100,
|
|
}, nil)
|
|
mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 0}, nil)
|
|
err := ut.PreExecute(ctx)
|
|
assert.Error(t, err)
|
|
assert.ErrorIs(t, err, merr.ErrCollectionSchemaMismatch)
|
|
})
|
|
}
|
|
|
|
// Helper function to create test updateTask
|
|
func createTestUpdateTask() *upsertTask {
|
|
mcClient := &grpcmixcoordclient.Client{}
|
|
|
|
upsertTask := &upsertTask{
|
|
baseTask: baseTask{},
|
|
Condition: NewTaskCondition(context.Background()),
|
|
req: &milvuspb.UpsertRequest{
|
|
DbName: "test_db",
|
|
CollectionName: "test_collection",
|
|
PartitionName: "_default",
|
|
FieldsData: []*schemapb.FieldData{
|
|
{
|
|
FieldName: "id",
|
|
FieldId: 100,
|
|
Type: schemapb.DataType_Int64,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_LongData{
|
|
LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
FieldName: "name",
|
|
FieldId: 102,
|
|
Type: schemapb.DataType_VarChar,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_StringData{
|
|
StringData: &schemapb.StringArray{Data: []string{"test1", "test2", "test3"}},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
FieldName: "vector",
|
|
FieldId: 101,
|
|
Type: schemapb.DataType_FloatVector,
|
|
Field: &schemapb.FieldData_Vectors{
|
|
Vectors: &schemapb.VectorField{
|
|
Dim: 128,
|
|
Data: &schemapb.VectorField_FloatVector{
|
|
FloatVector: &schemapb.FloatArray{Data: make([]float32, 384)}, // 3 * 128
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
NumRows: 3,
|
|
},
|
|
ctx: context.Background(),
|
|
schema: createTestSchema(),
|
|
collectionID: 1001,
|
|
node: &Proxy{
|
|
mixCoord: mcClient,
|
|
lbPolicy: NewLBPolicyImpl(nil),
|
|
},
|
|
}
|
|
|
|
return upsertTask
|
|
}
|
|
|
|
// Helper function to create test schema
|
|
func createTestSchema() *schemaInfo {
|
|
schema := &schemapb.CollectionSchema{
|
|
Name: "test_collection",
|
|
Fields: []*schemapb.FieldSchema{
|
|
{
|
|
FieldID: 100,
|
|
Name: "id",
|
|
IsPrimaryKey: true,
|
|
DataType: schemapb.DataType_Int64,
|
|
},
|
|
{
|
|
FieldID: 101,
|
|
Name: "vector",
|
|
DataType: schemapb.DataType_FloatVector,
|
|
TypeParams: []*commonpb.KeyValuePair{
|
|
{Key: "dim", Value: "128"},
|
|
},
|
|
},
|
|
{
|
|
FieldID: 102,
|
|
Name: "name",
|
|
DataType: schemapb.DataType_VarChar,
|
|
},
|
|
},
|
|
}
|
|
return newSchemaInfo(schema)
|
|
}
|
|
|
|
func TestRetrieveByPKs_Success(t *testing.T) {
|
|
mockey.PatchConvey("TestRetrieveByPKs_Success", t, func() {
|
|
// Setup mocks
|
|
mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{
|
|
FieldID: 100,
|
|
Name: "id",
|
|
IsPrimaryKey: true,
|
|
DataType: schemapb.DataType_Int64,
|
|
}, nil).Build()
|
|
|
|
mockey.Mock(validatePartitionTag).Return(nil).Build()
|
|
|
|
mockey.Mock((*MetaCache).GetPartitionID).Return(int64(1002), nil).Build()
|
|
|
|
mockey.Mock(planparserv2.CreateRequeryPlan).Return(&planpb.PlanNode{}).Build()
|
|
|
|
mockey.Mock((*Proxy).query).Return(&milvuspb.QueryResults{
|
|
Status: merr.Success(),
|
|
FieldsData: []*schemapb.FieldData{
|
|
{
|
|
FieldName: "id",
|
|
FieldId: 100,
|
|
Type: schemapb.DataType_Int64,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_LongData{
|
|
LongData: &schemapb.LongArray{Data: []int64{1, 2}},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}, nil).Build()
|
|
|
|
globalMetaCache = &MetaCache{}
|
|
mockey.Mock(globalMetaCache.GetPartitionID).Return(int64(1002), nil).Build()
|
|
|
|
// Execute test
|
|
task := createTestUpdateTask()
|
|
task.partitionKeyMode = false
|
|
task.upsertMsg = &msgstream.UpsertMsg{
|
|
InsertMsg: &msgstream.InsertMsg{
|
|
InsertRequest: &msgpb.InsertRequest{
|
|
PartitionName: "_default",
|
|
},
|
|
},
|
|
DeleteMsg: &msgstream.DeleteMsg{
|
|
DeleteRequest: &msgpb.DeleteRequest{
|
|
PartitionName: "_default",
|
|
},
|
|
},
|
|
}
|
|
|
|
ids := &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{Data: []int64{1, 2}},
|
|
},
|
|
}
|
|
|
|
result, err := retrieveByPKs(context.Background(), task, ids, []string{"*"})
|
|
|
|
// Verify results
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, result)
|
|
assert.Equal(t, commonpb.ErrorCode_Success, result.Status.ErrorCode)
|
|
assert.Len(t, result.FieldsData, 1)
|
|
})
|
|
}
|
|
|
|
func TestRetrieveByPKs_GetPrimaryFieldSchemaError(t *testing.T) {
|
|
mockey.PatchConvey("TestRetrieveByPKs_GetPrimaryFieldSchemaError", t, func() {
|
|
expectedErr := merr.WrapErrParameterInvalidMsg("primary field not found")
|
|
mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(nil, expectedErr).Build()
|
|
|
|
task := createTestUpdateTask()
|
|
ids := &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{Data: []int64{1, 2}},
|
|
},
|
|
}
|
|
|
|
result, err := retrieveByPKs(context.Background(), task, ids, []string{"*"})
|
|
|
|
assert.Error(t, err)
|
|
assert.Nil(t, result)
|
|
assert.Contains(t, err.Error(), "primary field not found")
|
|
})
|
|
}
|
|
|
|
func TestRetrieveByPKs_PartitionKeyMode(t *testing.T) {
|
|
mockey.PatchConvey("TestRetrieveByPKs_PartitionKeyMode", t, func() {
|
|
mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{
|
|
FieldID: 100,
|
|
Name: "id",
|
|
IsPrimaryKey: true,
|
|
DataType: schemapb.DataType_Int64,
|
|
}, nil).Build()
|
|
|
|
mockey.Mock(planparserv2.CreateRequeryPlan).Return(&planpb.PlanNode{}).Build()
|
|
|
|
mockey.Mock((*Proxy).query).Return(&milvuspb.QueryResults{
|
|
Status: merr.Success(),
|
|
FieldsData: []*schemapb.FieldData{},
|
|
}, nil).Build()
|
|
|
|
task := createTestUpdateTask()
|
|
task.partitionKeyMode = true
|
|
|
|
ids := &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{Data: []int64{1, 2}},
|
|
},
|
|
}
|
|
|
|
result, err := retrieveByPKs(context.Background(), task, ids, []string{"*"})
|
|
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, result)
|
|
})
|
|
}
|
|
|
|
func TestUpdateTask_queryPreExecute_Success(t *testing.T) {
|
|
mockey.PatchConvey("TestUpdateTask_queryPreExecute_Success", t, func() {
|
|
// Setup mocks
|
|
mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{
|
|
FieldID: 100,
|
|
Name: "id",
|
|
IsPrimaryKey: true,
|
|
DataType: schemapb.DataType_Int64,
|
|
}, nil).Build()
|
|
|
|
mockey.Mock(typeutil.GetPrimaryFieldData).Return(&schemapb.FieldData{
|
|
FieldName: "id",
|
|
FieldId: 100,
|
|
Type: schemapb.DataType_Int64,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_LongData{
|
|
LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}},
|
|
},
|
|
},
|
|
},
|
|
}, nil).Build()
|
|
|
|
mockey.Mock(parsePrimaryFieldData2IDs).Return(&schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{Data: []int64{1, 2, 3}},
|
|
},
|
|
}, nil).Build()
|
|
|
|
mockey.Mock(typeutil.GetSizeOfIDs).Return(3).Build()
|
|
|
|
mockey.Mock(retrieveByPKs).Return(&milvuspb.QueryResults{
|
|
Status: merr.Success(),
|
|
FieldsData: []*schemapb.FieldData{
|
|
{
|
|
FieldName: "id",
|
|
FieldId: 100,
|
|
Type: schemapb.DataType_Int64,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_LongData{
|
|
LongData: &schemapb.LongArray{Data: []int64{1, 2}},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
FieldName: "name",
|
|
FieldId: 102,
|
|
Type: schemapb.DataType_VarChar,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_StringData{
|
|
StringData: &schemapb.StringArray{Data: []string{"old1", "old2"}},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
FieldName: "vector",
|
|
FieldId: 101,
|
|
Type: schemapb.DataType_FloatVector,
|
|
Field: &schemapb.FieldData_Vectors{
|
|
Vectors: &schemapb.VectorField{
|
|
Dim: 128,
|
|
Data: &schemapb.VectorField_FloatVector{
|
|
FloatVector: &schemapb.FloatArray{Data: make([]float32, 256)}, // 2 * 128
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}, nil).Build()
|
|
|
|
mockey.Mock(typeutil.NewIDsChecker).Return(&typeutil.IDsChecker{}, nil).Build()
|
|
|
|
// Execute test
|
|
task := createTestUpdateTask()
|
|
task.schema = createTestSchema()
|
|
task.upsertMsg = &msgstream.UpsertMsg{
|
|
InsertMsg: &msgstream.InsertMsg{
|
|
InsertRequest: &msgpb.InsertRequest{
|
|
FieldsData: []*schemapb.FieldData{
|
|
{
|
|
FieldName: "id",
|
|
FieldId: 100,
|
|
Type: schemapb.DataType_Int64,
|
|
},
|
|
{
|
|
FieldName: "name",
|
|
FieldId: 102,
|
|
Type: schemapb.DataType_VarChar,
|
|
},
|
|
{
|
|
FieldName: "vector",
|
|
FieldId: 101,
|
|
Type: schemapb.DataType_FloatVector,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
err := task.queryPreExecute(context.Background())
|
|
|
|
// Verify results
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, task.deletePKs)
|
|
assert.NotNil(t, task.insertFieldData)
|
|
})
|
|
}
|
|
|
|
func TestUpdateTask_queryPreExecute_GetPrimaryFieldSchemaError(t *testing.T) {
|
|
mockey.PatchConvey("TestUpdateTask_queryPreExecute_GetPrimaryFieldSchemaError", t, func() {
|
|
expectedErr := merr.WrapErrParameterInvalidMsg("primary field not found")
|
|
mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(nil, expectedErr).Build()
|
|
|
|
task := createTestUpdateTask()
|
|
task.schema = createTestSchema()
|
|
|
|
err := task.queryPreExecute(context.Background())
|
|
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "primary field not found")
|
|
})
|
|
}
|
|
|
|
func TestUpdateTask_queryPreExecute_GetPrimaryFieldDataError(t *testing.T) {
|
|
mockey.PatchConvey("TestUpdateTask_queryPreExecute_GetPrimaryFieldDataError", t, func() {
|
|
mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{
|
|
FieldID: 100,
|
|
Name: "id",
|
|
IsPrimaryKey: true,
|
|
DataType: schemapb.DataType_Int64,
|
|
}, nil).Build()
|
|
|
|
expectedErr := merr.WrapErrParameterInvalidMsg("primary field data not found")
|
|
mockey.Mock(typeutil.GetPrimaryFieldData).Return(nil, expectedErr).Build()
|
|
|
|
task := createTestUpdateTask()
|
|
task.schema = createTestSchema()
|
|
|
|
err := task.queryPreExecute(context.Background())
|
|
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "must assign pk when upsert")
|
|
})
|
|
}
|
|
|
|
func TestUpdateTask_queryPreExecute_EmptyOldIDs(t *testing.T) {
|
|
mockey.PatchConvey("TestUpdateTask_queryPreExecute_EmptyOldIDs", t, func() {
|
|
mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{
|
|
FieldID: 100,
|
|
Name: "id",
|
|
IsPrimaryKey: true,
|
|
DataType: schemapb.DataType_Int64,
|
|
}, nil).Build()
|
|
|
|
mockey.Mock(typeutil.GetPrimaryFieldData).Return(&schemapb.FieldData{
|
|
FieldName: "id",
|
|
FieldId: 100,
|
|
Type: schemapb.DataType_Int64,
|
|
}, nil).Build()
|
|
|
|
mockey.Mock(parsePrimaryFieldData2IDs).Return(&schemapb.IDs{}, nil).Build()
|
|
|
|
mockey.Mock(typeutil.GetSizeOfIDs).Return(0).Build()
|
|
|
|
task := createTestUpdateTask()
|
|
task.schema = createTestSchema()
|
|
|
|
err := task.queryPreExecute(context.Background())
|
|
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, task.deletePKs)
|
|
assert.Equal(t, task.req.GetFieldsData(), task.insertFieldData)
|
|
})
|
|
}
|
|
|
|
func TestUpdateTask_PreExecute_Success(t *testing.T) {
|
|
mockey.PatchConvey("TestUpdateTask_PreExecute_Success", t, func() {
|
|
// Setup mocks
|
|
globalMetaCache = &MetaCache{}
|
|
|
|
mockey.Mock(GetReplicateID).Return("", nil).Build()
|
|
|
|
mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build()
|
|
|
|
mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{
|
|
updateTimestamp: 12345,
|
|
}, nil).Build()
|
|
|
|
mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build()
|
|
|
|
mockey.Mock(isPartitionKeyMode).Return(false, nil).Build()
|
|
|
|
mockey.Mock((*MetaCache).GetPartitionInfo).Return(&partitionInfo{
|
|
name: "_default",
|
|
}, nil).Build()
|
|
|
|
mockey.Mock((*upsertTask).queryPreExecute).Return(nil).Build()
|
|
|
|
mockey.Mock((*upsertTask).insertPreExecute).Return(nil).Build()
|
|
|
|
mockey.Mock((*upsertTask).deletePreExecute).Return(nil).Build()
|
|
|
|
// Execute test
|
|
task := createTestUpdateTask()
|
|
task.req.PartialUpdate = true
|
|
|
|
err := task.PreExecute(context.Background())
|
|
|
|
// Verify results
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, task.result)
|
|
assert.Equal(t, int64(1001), task.collectionID)
|
|
assert.NotNil(t, task.schema)
|
|
assert.NotNil(t, task.upsertMsg)
|
|
})
|
|
}
|
|
|
|
func TestUpdateTask_PreExecute_ReplicateIDError(t *testing.T) {
|
|
mockey.PatchConvey("TestUpdateTask_PreExecute_ReplicateIDError", t, func() {
|
|
globalMetaCache = &MetaCache{}
|
|
|
|
mockey.Mock(GetReplicateID).Return("replica1", nil).Build()
|
|
|
|
task := createTestUpdateTask()
|
|
|
|
err := task.PreExecute(context.Background())
|
|
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "can't operate on the collection under standby mode")
|
|
})
|
|
}
|
|
|
|
func TestUpdateTask_PreExecute_GetCollectionIDError(t *testing.T) {
|
|
mockey.PatchConvey("TestUpdateTask_PreExecute_GetCollectionIDError", t, func() {
|
|
globalMetaCache = &MetaCache{}
|
|
|
|
mockey.Mock(GetReplicateID).Return("", nil).Build()
|
|
|
|
expectedErr := merr.WrapErrCollectionNotFound("test_collection")
|
|
mockey.Mock((*MetaCache).GetCollectionID).Return(int64(0), expectedErr).Build()
|
|
|
|
task := createTestUpdateTask()
|
|
|
|
err := task.PreExecute(context.Background())
|
|
|
|
assert.Error(t, err)
|
|
})
|
|
}
|
|
|
|
func TestUpdateTask_PreExecute_PartitionKeyModeError(t *testing.T) {
|
|
mockey.PatchConvey("TestUpdateTask_PreExecute_PartitionKeyModeError", t, func() {
|
|
globalMetaCache = &MetaCache{}
|
|
|
|
mockey.Mock(GetReplicateID).Return("", nil).Build()
|
|
mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build()
|
|
mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{
|
|
updateTimestamp: 12345,
|
|
}, nil).Build()
|
|
mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build()
|
|
|
|
mockey.Mock(isPartitionKeyMode).Return(true, nil).Build()
|
|
|
|
task := createTestUpdateTask()
|
|
task.req.PartitionName = "custom_partition" // This should cause error in partition key mode
|
|
|
|
err := task.PreExecute(context.Background())
|
|
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "not support manually specifying the partition names if partition key mode is used")
|
|
})
|
|
}
|
|
|
|
func TestUpdateTask_PreExecute_InvalidNumRows(t *testing.T) {
|
|
mockey.PatchConvey("TestUpdateTask_PreExecute_InvalidNumRows", t, func() {
|
|
globalMetaCache = &MetaCache{}
|
|
|
|
mockey.Mock(GetReplicateID).Return("", nil).Build()
|
|
mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build()
|
|
mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{
|
|
updateTimestamp: 12345,
|
|
}, nil).Build()
|
|
mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build()
|
|
mockey.Mock(isPartitionKeyMode).Return(false, nil).Build()
|
|
mockey.Mock((*MetaCache).GetPartitionInfo).Return(&partitionInfo{
|
|
name: "_default",
|
|
}, nil).Build()
|
|
|
|
task := createTestUpdateTask()
|
|
task.req.NumRows = 0 // Invalid num_rows
|
|
|
|
err := task.PreExecute(context.Background())
|
|
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "invalid num_rows")
|
|
})
|
|
}
|
|
|
|
func TestUpdateTask_PreExecute_QueryPreExecuteError(t *testing.T) {
|
|
mockey.PatchConvey("TestUpdateTask_PreExecute_QueryPreExecuteError", t, func() {
|
|
globalMetaCache = &MetaCache{}
|
|
|
|
mockey.Mock(GetReplicateID).Return("", nil).Build()
|
|
mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build()
|
|
mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{
|
|
updateTimestamp: 12345,
|
|
}, nil).Build()
|
|
mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build()
|
|
mockey.Mock(isPartitionKeyMode).Return(false, nil).Build()
|
|
mockey.Mock((*MetaCache).GetPartitionInfo).Return(&partitionInfo{
|
|
name: "_default",
|
|
}, nil).Build()
|
|
|
|
expectedErr := merr.WrapErrParameterInvalidMsg("query pre-execute failed")
|
|
mockey.Mock((*upsertTask).queryPreExecute).Return(expectedErr).Build()
|
|
|
|
task := createTestUpdateTask()
|
|
task.req.PartialUpdate = true
|
|
|
|
err := task.PreExecute(context.Background())
|
|
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "query pre-execute failed")
|
|
})
|
|
}
|
|
|
|
func TestUpsertTask_queryPreExecute_MixLogic(t *testing.T) {
|
|
// Schema for the test collection
|
|
schema := newSchemaInfo(&schemapb.CollectionSchema{
|
|
Name: "test_merge_collection",
|
|
Fields: []*schemapb.FieldSchema{
|
|
{FieldID: 100, Name: "id", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
|
|
{FieldID: 101, Name: "value", DataType: schemapb.DataType_Int32},
|
|
{FieldID: 102, Name: "extra", DataType: schemapb.DataType_VarChar, Nullable: true},
|
|
},
|
|
})
|
|
|
|
// Upsert IDs: 1 (update), 2 (update), 3 (insert)
|
|
upsertData := []*schemapb.FieldData{
|
|
{
|
|
FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
|
|
Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}}}}},
|
|
},
|
|
{
|
|
FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
|
|
Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{100, 200, 300}}}}},
|
|
},
|
|
}
|
|
numRows := uint64(len(upsertData[0].GetScalars().GetLongData().GetData()))
|
|
|
|
// Query result for existing PKs: 1, 2
|
|
mockQueryResult := &milvuspb.QueryResults{
|
|
Status: merr.Success(),
|
|
FieldsData: []*schemapb.FieldData{
|
|
{
|
|
FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
|
|
Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2}}}}},
|
|
},
|
|
{
|
|
FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
|
|
Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{10, 20}}}}},
|
|
},
|
|
{
|
|
FieldName: "extra", FieldId: 102, Type: schemapb.DataType_VarChar,
|
|
Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: []string{"old1", "old2"}}}}},
|
|
ValidData: []bool{true, true},
|
|
},
|
|
},
|
|
}
|
|
|
|
task := &upsertTask{
|
|
ctx: context.Background(),
|
|
schema: schema,
|
|
req: &milvuspb.UpsertRequest{
|
|
FieldsData: upsertData,
|
|
NumRows: uint32(numRows),
|
|
},
|
|
upsertMsg: &msgstream.UpsertMsg{
|
|
InsertMsg: &msgstream.InsertMsg{
|
|
InsertRequest: &msgpb.InsertRequest{
|
|
FieldsData: upsertData,
|
|
NumRows: numRows,
|
|
},
|
|
},
|
|
},
|
|
node: &Proxy{},
|
|
}
|
|
|
|
mockRetrieve := mockey.Mock(retrieveByPKs).Return(mockQueryResult, nil).Build()
|
|
defer mockRetrieve.UnPatch()
|
|
|
|
err := task.queryPreExecute(context.Background())
|
|
assert.NoError(t, err)
|
|
|
|
// Verify delete PKs
|
|
deletePks := task.deletePKs.GetIntId().GetData()
|
|
assert.ElementsMatch(t, []int64{1, 2}, deletePks)
|
|
|
|
// Verify merged insert data
|
|
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema.CollectionSchema)
|
|
assert.NoError(t, err)
|
|
idField, err := typeutil.GetPrimaryFieldData(task.insertFieldData, primaryFieldSchema)
|
|
assert.NoError(t, err)
|
|
ids, err := parsePrimaryFieldData2IDs(idField)
|
|
assert.NoError(t, err)
|
|
insertPKs := ids.GetIntId().GetData()
|
|
assert.Equal(t, []int64{1, 2, 3}, insertPKs)
|
|
|
|
var valueField *schemapb.FieldData
|
|
for _, f := range task.insertFieldData {
|
|
if f.GetFieldName() == "value" {
|
|
valueField = f
|
|
break
|
|
}
|
|
}
|
|
assert.NotNil(t, valueField)
|
|
assert.Equal(t, []int32{100, 200, 300}, valueField.GetScalars().GetIntData().GetData())
|
|
}
|
|
|
|
func TestUpsertTask_queryPreExecute_PureInsert(t *testing.T) {
|
|
// Schema for the test collection
|
|
schema := newSchemaInfo(&schemapb.CollectionSchema{
|
|
Name: "test_merge_collection",
|
|
Fields: []*schemapb.FieldSchema{
|
|
{FieldID: 100, Name: "id", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
|
|
{FieldID: 101, Name: "value", DataType: schemapb.DataType_Int32},
|
|
{FieldID: 102, Name: "extra", DataType: schemapb.DataType_VarChar, Nullable: true},
|
|
},
|
|
})
|
|
|
|
// Upsert IDs: 4, 5
|
|
upsertData := []*schemapb.FieldData{
|
|
{
|
|
FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
|
|
Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{4, 5}}}}},
|
|
},
|
|
{
|
|
FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
|
|
Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{400, 500}}}}},
|
|
},
|
|
}
|
|
numRows := uint64(len(upsertData[0].GetScalars().GetLongData().GetData()))
|
|
|
|
// Query result is empty, but schema is preserved
|
|
mockQueryResult := &milvuspb.QueryResults{Status: merr.Success(), FieldsData: []*schemapb.FieldData{
|
|
{
|
|
FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
|
|
Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{}}}}},
|
|
},
|
|
{
|
|
FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
|
|
Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{}}}}},
|
|
},
|
|
{
|
|
FieldName: "extra", FieldId: 102, Type: schemapb.DataType_VarChar,
|
|
Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: []string{}}}}},
|
|
},
|
|
}}
|
|
|
|
task := &upsertTask{
|
|
ctx: context.Background(),
|
|
schema: schema,
|
|
req: &milvuspb.UpsertRequest{
|
|
FieldsData: upsertData,
|
|
NumRows: uint32(numRows),
|
|
},
|
|
upsertMsg: &msgstream.UpsertMsg{
|
|
InsertMsg: &msgstream.InsertMsg{
|
|
InsertRequest: &msgpb.InsertRequest{
|
|
FieldsData: upsertData,
|
|
NumRows: numRows,
|
|
},
|
|
},
|
|
},
|
|
node: &Proxy{},
|
|
}
|
|
|
|
mockRetrieve := mockey.Mock(retrieveByPKs).Return(mockQueryResult, nil).Build()
|
|
defer mockRetrieve.UnPatch()
|
|
|
|
err := task.queryPreExecute(context.Background())
|
|
assert.NoError(t, err)
|
|
|
|
// Verify delete PKs
|
|
deletePks := task.deletePKs.GetIntId().GetData()
|
|
assert.Empty(t, deletePks)
|
|
|
|
// Verify merged insert data
|
|
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema.CollectionSchema)
|
|
assert.NoError(t, err)
|
|
idField, err := typeutil.GetPrimaryFieldData(task.insertFieldData, primaryFieldSchema)
|
|
assert.NoError(t, err)
|
|
ids, err := parsePrimaryFieldData2IDs(idField)
|
|
assert.NoError(t, err)
|
|
insertPKs := ids.GetIntId().GetData()
|
|
assert.Equal(t, []int64{4, 5}, insertPKs)
|
|
|
|
var valueField *schemapb.FieldData
|
|
for _, f := range task.insertFieldData {
|
|
if f.GetFieldName() == "value" {
|
|
valueField = f
|
|
break
|
|
}
|
|
}
|
|
assert.NotNil(t, valueField)
|
|
assert.Equal(t, []int32{400, 500}, valueField.GetScalars().GetIntData().GetData())
|
|
}
|
|
|
|
func TestUpsertTask_queryPreExecute_PureUpdate(t *testing.T) {
|
|
// Schema for the test collection
|
|
schema := newSchemaInfo(&schemapb.CollectionSchema{
|
|
Name: "test_merge_collection",
|
|
Fields: []*schemapb.FieldSchema{
|
|
{FieldID: 100, Name: "id", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
|
|
{FieldID: 101, Name: "value", DataType: schemapb.DataType_Int32},
|
|
{FieldID: 102, Name: "extra", DataType: schemapb.DataType_VarChar, Nullable: true},
|
|
},
|
|
})
|
|
|
|
// Upsert IDs: 6, 7
|
|
upsertData := []*schemapb.FieldData{
|
|
{
|
|
FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
|
|
Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{6, 7}}}}},
|
|
},
|
|
{
|
|
FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
|
|
Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{600, 700}}}}},
|
|
},
|
|
}
|
|
numRows := uint64(len(upsertData[0].GetScalars().GetLongData().GetData()))
|
|
|
|
// Query result for existing PKs: 6, 7
|
|
mockQueryResult := &milvuspb.QueryResults{
|
|
Status: merr.Success(),
|
|
FieldsData: []*schemapb.FieldData{
|
|
{
|
|
FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
|
|
Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{6, 7}}}}},
|
|
},
|
|
{
|
|
FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
|
|
Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{60, 70}}}}},
|
|
},
|
|
},
|
|
}
|
|
|
|
task := &upsertTask{
|
|
ctx: context.Background(),
|
|
schema: schema,
|
|
req: &milvuspb.UpsertRequest{
|
|
FieldsData: upsertData,
|
|
NumRows: uint32(numRows),
|
|
},
|
|
upsertMsg: &msgstream.UpsertMsg{
|
|
InsertMsg: &msgstream.InsertMsg{
|
|
InsertRequest: &msgpb.InsertRequest{
|
|
FieldsData: upsertData,
|
|
NumRows: numRows,
|
|
},
|
|
},
|
|
},
|
|
node: &Proxy{},
|
|
}
|
|
|
|
mockRetrieve := mockey.Mock(retrieveByPKs).Return(mockQueryResult, nil).Build()
|
|
defer mockRetrieve.UnPatch()
|
|
|
|
err := task.queryPreExecute(context.Background())
|
|
assert.NoError(t, err)
|
|
|
|
// Verify delete PKs
|
|
deletePks := task.deletePKs.GetIntId().GetData()
|
|
assert.ElementsMatch(t, []int64{6, 7}, deletePks)
|
|
|
|
// Verify merged insert data
|
|
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema.CollectionSchema)
|
|
assert.NoError(t, err)
|
|
idField, err := typeutil.GetPrimaryFieldData(task.insertFieldData, primaryFieldSchema)
|
|
assert.NoError(t, err)
|
|
ids, err := parsePrimaryFieldData2IDs(idField)
|
|
assert.NoError(t, err)
|
|
insertPKs := ids.GetIntId().GetData()
|
|
assert.Equal(t, []int64{6, 7}, insertPKs)
|
|
|
|
var valueField *schemapb.FieldData
|
|
for _, f := range task.insertFieldData {
|
|
if f.GetFieldName() == "value" {
|
|
valueField = f
|
|
break
|
|
}
|
|
}
|
|
assert.NotNil(t, valueField)
|
|
assert.Equal(t, []int32{600, 700}, valueField.GetScalars().GetIntData().GetData())
|
|
}
|