// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package proxy

import (
	"context"
	"testing"

	"github.com/bytedance/mockey"
	"github.com/cockroachdb/errors"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/allocator"
	grpcmixcoordclient "github.com/milvus-io/milvus/internal/distributed/mixcoord/client"
	"github.com/milvus-io/milvus/internal/mocks"
	"github.com/milvus-io/milvus/internal/parser/planparserv2"
	"github.com/milvus-io/milvus/internal/proxy/shardclient"
	"github.com/milvus-io/milvus/internal/util/function/embedding"
	"github.com/milvus-io/milvus/internal/util/segcore"
	"github.com/milvus-io/milvus/pkg/v2/common"
	"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
	"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
	"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
	"github.com/milvus-io/milvus/pkg/v2/util/merr"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
	"github.com/milvus-io/milvus/pkg/v2/util/testutils"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

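// TestUpsertTask_CheckAligned exercises InsertMsg.CheckAligned, which verifies
// that every field column carries exactly NumRows entries; each block below
// shrinks one column, grows it, and then restores the aligned length.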
func TestUpsertTask_CheckAligned(t *testing.T) {
	var err error

	// passed NumRows is 0
	case1 := upsertTask{
		req: &milvuspb.UpsertRequest{
			NumRows: 0,
		},
		upsertMsg: &msgstream.UpsertMsg{
			InsertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{},
			},
		},
	}
	case1.upsertMsg.InsertMsg.InsertRequest = &msgpb.InsertRequest{
		Base: commonpbutil.NewMsgBase(
			commonpbutil.WithMsgType(commonpb.MsgType_Insert),
		),
		CollectionName: case1.req.CollectionName,
		PartitionName:  case1.req.PartitionName,
		FieldsData:     case1.req.FieldsData,
		NumRows:        uint64(case1.req.NumRows),
		Version:        msgpb.InsertDataVersion_ColumnBased,
	}

	err = case1.upsertMsg.InsertMsg.CheckAligned()
	assert.NoError(t, err)

	// checkFieldsDataBySchema is already covered by TestUpsertTask_checkFieldsDataBySchema

	boolFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}
	int8FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int8}
	int16FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int16}
	int32FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int32}
	int64FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}
	floatFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Float}
	doubleFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Double}
	floatVectorFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector}
	binaryVectorFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_BinaryVector}
	varCharFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}

	numRows := 20
	dim := 128
	collSchema := &schemapb.CollectionSchema{
		Name:        "TestUpsertTask_checkRowNums",
		Description: "TestUpsertTask_checkRowNums",
		AutoID:      false,
		Fields: []*schemapb.FieldSchema{
			boolFieldSchema,
			int8FieldSchema,
			int16FieldSchema,
			int32FieldSchema,
			int64FieldSchema,
			floatFieldSchema,
			doubleFieldSchema,
			floatVectorFieldSchema,
			binaryVectorFieldSchema,
			varCharFieldSchema,
		},
	}
	schema := newSchemaInfo(collSchema)
	case2 := upsertTask{
		req: &milvuspb.UpsertRequest{
			NumRows:    uint32(numRows),
			FieldsData: []*schemapb.FieldData{},
		},
		rowIDs:     testutils.GenerateInt64Array(numRows),
		timestamps: testutils.GenerateUint64Array(numRows),
		schema:     schema,
		upsertMsg: &msgstream.UpsertMsg{
			InsertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{},
			},
		},
	}

	// satisfied
	case2.req.FieldsData = []*schemapb.FieldData{
		newScalarFieldData(boolFieldSchema, "Bool", numRows),
		newScalarFieldData(int8FieldSchema, "Int8", numRows),
		newScalarFieldData(int16FieldSchema, "Int16", numRows),
		newScalarFieldData(int32FieldSchema, "Int32", numRows),
		newScalarFieldData(int64FieldSchema, "Int64", numRows),
		newScalarFieldData(floatFieldSchema, "Float", numRows),
		newScalarFieldData(doubleFieldSchema, "Double", numRows),
		newFloatVectorFieldData("FloatVector", numRows, dim),
		newBinaryVectorFieldData("BinaryVector", numRows, dim),
		newScalarFieldData(varCharFieldSchema, "VarChar", numRows),
	}
	case2.upsertMsg.InsertMsg.InsertRequest = &msgpb.InsertRequest{
		Base: commonpbutil.NewMsgBase(
			commonpbutil.WithMsgType(commonpb.MsgType_Insert),
		),
		CollectionName: case2.req.CollectionName,
		PartitionName:  case2.req.PartitionName,
		FieldsData:     case2.req.FieldsData,
		NumRows:        uint64(case2.req.NumRows),
		RowIDs:         case2.rowIDs,
		Timestamps:     case2.timestamps,
		Version:        msgpb.InsertDataVersion_ColumnBased,
	}
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.NoError(t, err)

	// less bool data
	case2.req.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows/2)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// more bool data
	case2.req.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows*2)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// revert
	case2.req.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.NoError(t, err)

	// less int8 data
	case2.req.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows/2)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// more int8 data
	case2.req.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows*2)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// revert
	case2.req.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.NoError(t, err)

	// less int16 data
	case2.req.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows/2)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// more int16 data
	case2.req.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows*2)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// revert
	case2.req.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.NoError(t, err)

	// less int32 data
	case2.req.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows/2)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// more int32 data
	case2.req.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows*2)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// revert
	case2.req.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.NoError(t, err)

	// less int64 data
	case2.req.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows/2)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// more int64 data
	case2.req.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows*2)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// revert
	case2.req.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.NoError(t, err)

	// less float data
	case2.req.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows/2)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// more float data
	case2.req.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows*2)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// revert
	case2.req.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.NoError(t, err)

	// less double data
	case2.req.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows/2)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// more double data
	case2.req.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows*2)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// revert
	case2.req.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.NoError(t, err)

	// less float vectors
	case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows/2, dim)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// more float vectors
	case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows*2, dim)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// revert
	case2.req.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows, dim)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.NoError(t, err)

	// less binary vectors
	case2.req.FieldsData[8] = newBinaryVectorFieldData("BinaryVector", numRows/2, dim)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// more binary vectors
	case2.req.FieldsData[8] = newBinaryVectorFieldData("BinaryVector", numRows*2, dim)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// revert
	case2.req.FieldsData[8] = newBinaryVectorFieldData("BinaryVector", numRows, dim)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.NoError(t, err)

	// less varChar data
	case2.req.FieldsData[9] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows/2)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// more varChar data
	case2.req.FieldsData[9] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows*2)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.Error(t, err)
	// revert
	case2.req.FieldsData[9] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows)
	case2.upsertMsg.InsertMsg.FieldsData = case2.req.FieldsData
	err = case2.upsertMsg.InsertMsg.CheckAligned()
	assert.NoError(t, err)
}

func TestUpsertTask(t *testing.T) {
	t.Run("test getChannels", func(t *testing.T) {
		collectionID := UniqueID(0)
		collectionName := "col-0"
		channels := []pChan{"mock-chan-0", "mock-chan-1"}
		cache := NewMockCache(t)
		cache.On("GetCollectionID",
			mock.Anything, // context.Context
			mock.AnythingOfType("string"),
			mock.AnythingOfType("string"),
		).Return(collectionID, nil)
		globalMetaCache = cache

		chMgr := NewMockChannelsMgr(t)
		chMgr.EXPECT().getChannels(mock.Anything).Return(channels, nil)
		ut := upsertTask{
			ctx: context.Background(),
			req: &milvuspb.UpsertRequest{
				CollectionName: collectionName,
			},
			chMgr: chMgr,
		}
		err := ut.setChannels()
		assert.NoError(t, err)
		resChannels := ut.getChannels()
		assert.ElementsMatch(t, channels, resChannels)
		assert.ElementsMatch(t, channels, ut.pChannels)
	})
}

func TestUpsertTaskForReplicate(t *testing.T) {
	cache := globalMetaCache
	defer func() { globalMetaCache = cache }()
	mockCache := NewMockCache(t)
	globalMetaCache = mockCache
	ctx := context.Background()

	t.Run("fail to get collection info", func(t *testing.T) {
		ut := upsertTask{
			ctx: ctx,
			req: &milvuspb.UpsertRequest{
				CollectionName: "col-0",
			},
		}
		mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, errors.New("foo")).Once()
		err := ut.PreExecute(ctx)
		assert.Error(t, err)
	})

	t.Run("replicate mode", func(t *testing.T) {
		ut := upsertTask{
			ctx: ctx,
			req: &milvuspb.UpsertRequest{
				CollectionName: "col-0",
			},
		}
		mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
			replicateID: "local-mac",
		}, nil).Once()
		err := ut.PreExecute(ctx)
		assert.Error(t, err)
	})
}

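// TestUpsertTask_Function routes the "text" column through a mocked OpenAI
// embedding endpoint so that genFunctionFields can populate the
// function-output "vector" field before insertPreExecute validates the batch.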
func TestUpsertTask_Function(t *testing.T) {
	paramtable.Init()
	paramtable.Get().CredentialCfg.Credential.GetFunc = func() map[string]string {
		return map[string]string{
			"mock.apikey": "mock",
		}
	}
	ts := embedding.CreateOpenAIEmbeddingServer()
	defer ts.Close()
	paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
		return map[string]string{
			"openai.url": ts.URL,
		}
	}

	data := []*schemapb.FieldData{}
	f1 := schemapb.FieldData{
		Type:      schemapb.DataType_Int64,
		FieldId:   100,
		FieldName: "id",
		IsDynamic: false,
		Field: &schemapb.FieldData_Scalars{
			Scalars: &schemapb.ScalarField{
				Data: &schemapb.ScalarField_LongData{
					LongData: &schemapb.LongArray{
						Data: []int64{0, 1},
					},
				},
			},
		},
	}
	data = append(data, &f1)
	f2 := schemapb.FieldData{
		Type:      schemapb.DataType_VarChar,
		FieldId:   101,
		FieldName: "text",
		IsDynamic: false,
		Field: &schemapb.FieldData_Scalars{
			Scalars: &schemapb.ScalarField{
				Data: &schemapb.ScalarField_StringData{
					StringData: &schemapb.StringArray{
						Data: []string{"sentence", "sentence"},
					},
				},
			},
		},
	}
	data = append(data, &f2)
	collectionName := "TestUpsertTask_function"
	schema := &schemapb.CollectionSchema{
		Name:        collectionName,
		Description: "TestUpsertTask_function",
		AutoID:      true,
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
			{
				FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: "max_length", Value: "200"},
				},
			},
			{
				FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: "dim", Value: "4"},
				},
				IsFunctionOutput: true,
			},
		},
		Functions: []*schemapb.FunctionSchema{
			{
				Name:             "test_function",
				Type:             schemapb.FunctionType_TextEmbedding,
				InputFieldIds:    []int64{101},
				InputFieldNames:  []string{"text"},
				OutputFieldIds:   []int64{102},
				OutputFieldNames: []string{"vector"},
				Params: []*commonpb.KeyValuePair{
					{Key: "provider", Value: "openai"},
					{Key: "model_name", Value: "text-embedding-ada-002"},
					{Key: "credential", Value: "mock"},
					{Key: "dim", Value: "4"},
				},
			},
		},
	}

	info := newSchemaInfo(schema)
	collectionID := UniqueID(0)
	cache := NewMockCache(t)
	globalMetaCache = cache

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	rc := mocks.NewMockRootCoordClient(t)
	rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{
		Status: merr.Status(nil),
		ID:     collectionID,
		Count:  10,
	}, nil)
	idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0)
	assert.NoError(t, err)
	idAllocator.Start()
	defer idAllocator.Close()
	task := upsertTask{
		ctx: context.Background(),
		req: &milvuspb.UpsertRequest{
			CollectionName: collectionName,
		},
		upsertMsg: &msgstream.UpsertMsg{
			InsertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{
					Base: commonpbutil.NewMsgBase(
						commonpbutil.WithMsgType(commonpb.MsgType_Insert),
					),
					CollectionName: collectionName,
					DbName:         "hooooooo",
					Version:        msgpb.InsertDataVersion_ColumnBased,
					FieldsData:     data,
					NumRows:        2,
					PartitionName:  Params.CommonCfg.DefaultPartitionName.GetValue(),
				},
			},
		},
		idAllocator: idAllocator,
		schema:      info,
		result:      &milvuspb.MutationResult{},
	}
	err = genFunctionFields(task.ctx, task.upsertMsg.InsertMsg, task.schema, task.req.GetPartialUpdate())
	assert.NoError(t, err)
	err = task.insertPreExecute(ctx)
	assert.NoError(t, err)

	// process failed
	{
		oldRows := task.upsertMsg.InsertMsg.InsertRequest.NumRows
		task.upsertMsg.InsertMsg.InsertRequest.NumRows = 10000
		err = task.insertPreExecute(ctx)
		assert.Error(t, err)
		task.upsertMsg.InsertMsg.InsertRequest.NumRows = oldRows
	}
}

func TestUpsertTaskForSchemaMismatch(t *testing.T) {
	cache := globalMetaCache
	defer func() { globalMetaCache = cache }()
	mockCache := NewMockCache(t)
	globalMetaCache = mockCache
	ctx := context.Background()

	t.Run("schema ts mismatch", func(t *testing.T) {
		ut := upsertTask{
			ctx: ctx,
			req: &milvuspb.UpsertRequest{
				CollectionName: "col-0",
				NumRows:        10,
			},
			schemaTimestamp: 99,
		}
		mockCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil)
		mockCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
			updateTimestamp: 100,
		}, nil)
		mockCache.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 0}, nil)
		err := ut.PreExecute(ctx)
		assert.Error(t, err)
		assert.ErrorIs(t, err, merr.ErrCollectionSchemaMismatch)
	})
}

// Helper function to create test updateTask
func createTestUpdateTask() *upsertTask {
	mcClient := &grpcmixcoordclient.Client{}

	upsertTask := &upsertTask{
		baseTask:  baseTask{},
		Condition: NewTaskCondition(context.Background()),
		req: &milvuspb.UpsertRequest{
			DbName:         "test_db",
			CollectionName: "test_collection",
			PartitionName:  "_default",
			FieldsData: []*schemapb.FieldData{
				{
					FieldName: "id",
					FieldId:   100,
					Type:      schemapb.DataType_Int64,
					Field: &schemapb.FieldData_Scalars{
						Scalars: &schemapb.ScalarField{
							Data: &schemapb.ScalarField_LongData{
								LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}},
							},
						},
					},
				},
				{
					FieldName: "name",
					FieldId:   102,
					Type:      schemapb.DataType_VarChar,
					Field: &schemapb.FieldData_Scalars{
						Scalars: &schemapb.ScalarField{
							Data: &schemapb.ScalarField_StringData{
								StringData: &schemapb.StringArray{Data: []string{"test1", "test2", "test3"}},
							},
						},
					},
				},
				{
					FieldName: "vector",
					FieldId:   101,
					Type:      schemapb.DataType_FloatVector,
					Field: &schemapb.FieldData_Vectors{
						Vectors: &schemapb.VectorField{
							Dim: 128,
							Data: &schemapb.VectorField_FloatVector{
								FloatVector: &schemapb.FloatArray{Data: make([]float32, 384)}, // 3 * 128
							},
						},
					},
				},
			},
			NumRows: 3,
		},
		ctx:          context.Background(),
		schema:       createTestSchema(),
		collectionID: 1001,
		node: &Proxy{
			mixCoord: mcClient,
			lbPolicy: shardclient.NewLBPolicyImpl(nil),
		},
	}

	return upsertTask
}

// Helper function to create test schema
func createTestSchema() *schemaInfo {
	schema := &schemapb.CollectionSchema{
		Name: "test_collection",
		Fields: []*schemapb.FieldSchema{
			{
				FieldID:      100,
				Name:         "id",
				IsPrimaryKey: true,
				DataType:     schemapb.DataType_Int64,
			},
			{
				FieldID:  101,
				Name:     "vector",
				DataType: schemapb.DataType_FloatVector,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: "dim", Value: "128"},
				},
			},
			{
				FieldID:  102,
				Name:     "name",
				DataType: schemapb.DataType_VarChar,
			},
		},
	}
	return newSchemaInfo(schema)
}

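// The mockey-based tests below patch package-level functions and methods in
// place; mockey.PatchConvey reverts every patch when the enclosing test
// finishes, so the patched globals do not leak between tests.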
func TestRetrieveByPKs_Success(t *testing.T) {
	mockey.PatchConvey("TestRetrieveByPKs_Success", t, func() {
		// Setup mocks
		mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{
			FieldID:      100,
			Name:         "id",
			IsPrimaryKey: true,
			DataType:     schemapb.DataType_Int64,
		}, nil).Build()

		mockey.Mock(validatePartitionTag).Return(nil).Build()

		mockey.Mock((*MetaCache).GetPartitionID).Return(int64(1002), nil).Build()

		mockey.Mock(planparserv2.CreateRequeryPlan).Return(&planpb.PlanNode{}).Build()

		mockey.Mock((*Proxy).query).Return(&milvuspb.QueryResults{
			Status: merr.Success(),
			FieldsData: []*schemapb.FieldData{
				{
					FieldName: "id",
					FieldId:   100,
					Type:      schemapb.DataType_Int64,
					Field: &schemapb.FieldData_Scalars{
						Scalars: &schemapb.ScalarField{
							Data: &schemapb.ScalarField_LongData{
								LongData: &schemapb.LongArray{Data: []int64{1, 2}},
							},
						},
					},
				},
			},
		}, segcore.StorageCost{}, nil).Build()

		globalMetaCache = &MetaCache{}
		mockey.Mock(globalMetaCache.GetPartitionID).Return(int64(1002), nil).Build()

		// Execute test
		task := createTestUpdateTask()
		task.partitionKeyMode = false
		task.upsertMsg = &msgstream.UpsertMsg{
			InsertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{
					PartitionName: "_default",
				},
			},
			DeleteMsg: &msgstream.DeleteMsg{
				DeleteRequest: &msgpb.DeleteRequest{
					PartitionName: "_default",
				},
			},
		}

		ids := &schemapb.IDs{
			IdField: &schemapb.IDs_IntId{
				IntId: &schemapb.LongArray{Data: []int64{1, 2}},
			},
		}

		result, _, err := retrieveByPKs(context.Background(), task, ids, []string{"*"})

		// Verify results
		assert.NoError(t, err)
		assert.NotNil(t, result)
		assert.Equal(t, commonpb.ErrorCode_Success, result.Status.ErrorCode)
		assert.Len(t, result.FieldsData, 1)
	})
}

func TestRetrieveByPKs_GetPrimaryFieldSchemaError(t *testing.T) {
	mockey.PatchConvey("TestRetrieveByPKs_GetPrimaryFieldSchemaError", t, func() {
		expectedErr := merr.WrapErrParameterInvalidMsg("primary field not found")
		mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(nil, expectedErr).Build()

		task := createTestUpdateTask()
		ids := &schemapb.IDs{
			IdField: &schemapb.IDs_IntId{
				IntId: &schemapb.LongArray{Data: []int64{1, 2}},
			},
		}

		result, _, err := retrieveByPKs(context.Background(), task, ids, []string{"*"})

		assert.Error(t, err)
		assert.Nil(t, result)
		assert.Contains(t, err.Error(), "primary field not found")
	})
}

func TestRetrieveByPKs_PartitionKeyMode(t *testing.T) {
	mockey.PatchConvey("TestRetrieveByPKs_PartitionKeyMode", t, func() {
		mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{
			FieldID:      100,
			Name:         "id",
			IsPrimaryKey: true,
			DataType:     schemapb.DataType_Int64,
		}, nil).Build()

		mockey.Mock(planparserv2.CreateRequeryPlan).Return(&planpb.PlanNode{}).Build()

		mockey.Mock((*Proxy).query).Return(&milvuspb.QueryResults{
			Status:     merr.Success(),
			FieldsData: []*schemapb.FieldData{},
		}, segcore.StorageCost{}, nil).Build()

		task := createTestUpdateTask()
		task.partitionKeyMode = true

		ids := &schemapb.IDs{
			IdField: &schemapb.IDs_IntId{
				IntId: &schemapb.LongArray{Data: []int64{1, 2}},
			},
		}

		result, _, err := retrieveByPKs(context.Background(), task, ids, []string{"*"})

		assert.NoError(t, err)
		assert.NotNil(t, result)
	})
}

func TestUpdateTask_queryPreExecute_Success(t *testing.T) {
	mockey.PatchConvey("TestUpdateTask_queryPreExecute_Success", t, func() {
		// Setup mocks
		mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{
			FieldID:      100,
			Name:         "id",
			IsPrimaryKey: true,
			DataType:     schemapb.DataType_Int64,
		}, nil).Build()

		mockey.Mock(typeutil.GetPrimaryFieldData).Return(&schemapb.FieldData{
			FieldName: "id",
			FieldId:   100,
			Type:      schemapb.DataType_Int64,
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_LongData{
						LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}},
					},
				},
			},
		}, nil).Build()

		mockey.Mock(parsePrimaryFieldData2IDs).Return(&schemapb.IDs{
			IdField: &schemapb.IDs_IntId{
				IntId: &schemapb.LongArray{Data: []int64{1, 2, 3}},
			},
		}, nil).Build()

		mockey.Mock(typeutil.GetSizeOfIDs).Return(3).Build()

		mockey.Mock(retrieveByPKs).Return(&milvuspb.QueryResults{
			Status: merr.Success(),
			FieldsData: []*schemapb.FieldData{
				{
					FieldName: "id",
					FieldId:   100,
					Type:      schemapb.DataType_Int64,
					Field: &schemapb.FieldData_Scalars{
						Scalars: &schemapb.ScalarField{
							Data: &schemapb.ScalarField_LongData{
								LongData: &schemapb.LongArray{Data: []int64{1, 2}},
							},
						},
					},
				},
				{
					FieldName: "name",
					FieldId:   102,
					Type:      schemapb.DataType_VarChar,
					Field: &schemapb.FieldData_Scalars{
						Scalars: &schemapb.ScalarField{
							Data: &schemapb.ScalarField_StringData{
								StringData: &schemapb.StringArray{Data: []string{"old1", "old2"}},
							},
						},
					},
				},
				{
					FieldName: "vector",
					FieldId:   101,
					Type:      schemapb.DataType_FloatVector,
					Field: &schemapb.FieldData_Vectors{
						Vectors: &schemapb.VectorField{
							Dim: 128,
							Data: &schemapb.VectorField_FloatVector{
								FloatVector: &schemapb.FloatArray{Data: make([]float32, 256)}, // 2 * 128
							},
						},
					},
				},
			},
		}, segcore.StorageCost{}, nil).Build()

		mockey.Mock(typeutil.NewIDsChecker).Return(&typeutil.IDsChecker{}, nil).Build()

		// Execute test
		task := createTestUpdateTask()
		task.schema = createTestSchema()
		task.upsertMsg = &msgstream.UpsertMsg{
			InsertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{
					FieldsData: []*schemapb.FieldData{
						{
							FieldName: "id",
							FieldId:   100,
							Type:      schemapb.DataType_Int64,
						},
						{
							FieldName: "name",
							FieldId:   102,
							Type:      schemapb.DataType_VarChar,
						},
						{
							FieldName: "vector",
							FieldId:   101,
							Type:      schemapb.DataType_FloatVector,
						},
					},
				},
			},
		}

		err := task.queryPreExecute(context.Background())

		// Verify results
		assert.NoError(t, err)
		assert.NotNil(t, task.deletePKs)
		assert.NotNil(t, task.insertFieldData)
	})
}

func TestUpdateTask_queryPreExecute_GetPrimaryFieldSchemaError(t *testing.T) {
	mockey.PatchConvey("TestUpdateTask_queryPreExecute_GetPrimaryFieldSchemaError", t, func() {
		expectedErr := merr.WrapErrParameterInvalidMsg("primary field not found")
		mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(nil, expectedErr).Build()

		task := createTestUpdateTask()
		task.schema = createTestSchema()

		err := task.queryPreExecute(context.Background())

		assert.Error(t, err)
		assert.Contains(t, err.Error(), "primary field not found")
	})
}

func TestUpdateTask_queryPreExecute_GetPrimaryFieldDataError(t *testing.T) {
	mockey.PatchConvey("TestUpdateTask_queryPreExecute_GetPrimaryFieldDataError", t, func() {
		mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{
			FieldID:      100,
			Name:         "id",
			IsPrimaryKey: true,
			DataType:     schemapb.DataType_Int64,
		}, nil).Build()

		expectedErr := merr.WrapErrParameterInvalidMsg("primary field data not found")
		mockey.Mock(typeutil.GetPrimaryFieldData).Return(nil, expectedErr).Build()

		task := createTestUpdateTask()
		task.schema = createTestSchema()

		err := task.queryPreExecute(context.Background())

		assert.Error(t, err)
		assert.Contains(t, err.Error(), "must assign pk when upsert")
	})
}

func TestUpdateTask_queryPreExecute_EmptyOldIDs(t *testing.T) {
	mockey.PatchConvey("TestUpdateTask_queryPreExecute_EmptyOldIDs", t, func() {
		mockey.Mock(typeutil.GetPrimaryFieldSchema).Return(&schemapb.FieldSchema{
			FieldID:      100,
			Name:         "id",
			IsPrimaryKey: true,
			DataType:     schemapb.DataType_Int64,
		}, nil).Build()

		mockey.Mock(typeutil.GetPrimaryFieldData).Return(&schemapb.FieldData{
			FieldName: "id",
			FieldId:   100,
			Type:      schemapb.DataType_Int64,
		}, nil).Build()

		mockey.Mock(parsePrimaryFieldData2IDs).Return(&schemapb.IDs{}, nil).Build()

		mockey.Mock(typeutil.GetSizeOfIDs).Return(0).Build()

		task := createTestUpdateTask()
		task.schema = createTestSchema()

		err := task.queryPreExecute(context.Background())

		assert.NoError(t, err)
		assert.NotNil(t, task.deletePKs)
		assert.Equal(t, task.req.GetFieldsData(), task.insertFieldData)
	})
}

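// PreExecute wires the whole upsert pipeline together; the tests below patch
// each collaborator (replicate check, meta-cache lookups, and the
// query/insert/delete pre-execution stages) so that one failing stage can be
// isolated per test.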
func TestUpdateTask_PreExecute_Success(t *testing.T) {
	mockey.PatchConvey("TestUpdateTask_PreExecute_Success", t, func() {
		// Setup mocks
		globalMetaCache = &MetaCache{}

		mockey.Mock(GetReplicateID).Return("", nil).Build()

		mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build()

		mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{
			updateTimestamp: 12345,
		}, nil).Build()

		mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build()

		mockey.Mock(isPartitionKeyMode).Return(false, nil).Build()

		mockey.Mock((*MetaCache).GetPartitionInfo).Return(&partitionInfo{
			name: "_default",
		}, nil).Build()

		mockey.Mock((*upsertTask).queryPreExecute).Return(nil).Build()

		mockey.Mock((*upsertTask).insertPreExecute).Return(nil).Build()

		mockey.Mock((*upsertTask).deletePreExecute).Return(nil).Build()

		// Execute test
		task := createTestUpdateTask()
		task.req.PartialUpdate = true

		err := task.PreExecute(context.Background())

		// Verify results
		assert.NoError(t, err)
		assert.NotNil(t, task.result)
		assert.Equal(t, int64(1001), task.collectionID)
		assert.NotNil(t, task.schema)
		assert.NotNil(t, task.upsertMsg)
	})
}

func TestUpdateTask_PreExecute_ReplicateIDError(t *testing.T) {
	mockey.PatchConvey("TestUpdateTask_PreExecute_ReplicateIDError", t, func() {
		globalMetaCache = &MetaCache{}

		mockey.Mock(GetReplicateID).Return("replica1", nil).Build()

		task := createTestUpdateTask()

		err := task.PreExecute(context.Background())

		assert.Error(t, err)
		assert.Contains(t, err.Error(), "can't operate on the collection under standby mode")
	})
}

func TestUpdateTask_PreExecute_GetCollectionIDError(t *testing.T) {
	mockey.PatchConvey("TestUpdateTask_PreExecute_GetCollectionIDError", t, func() {
		globalMetaCache = &MetaCache{}

		mockey.Mock(GetReplicateID).Return("", nil).Build()

		expectedErr := merr.WrapErrCollectionNotFound("test_collection")
		mockey.Mock((*MetaCache).GetCollectionID).Return(int64(0), expectedErr).Build()

		task := createTestUpdateTask()

		err := task.PreExecute(context.Background())

		assert.Error(t, err)
	})
}

func TestUpdateTask_PreExecute_PartitionKeyModeError(t *testing.T) {
	mockey.PatchConvey("TestUpdateTask_PreExecute_PartitionKeyModeError", t, func() {
		globalMetaCache = &MetaCache{}

		mockey.Mock(GetReplicateID).Return("", nil).Build()
		mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build()
		mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{
			updateTimestamp: 12345,
		}, nil).Build()
		mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build()

		mockey.Mock(isPartitionKeyMode).Return(true, nil).Build()

		task := createTestUpdateTask()
		task.req.PartitionName = "custom_partition" // This should cause error in partition key mode

		err := task.PreExecute(context.Background())

		assert.Error(t, err)
		assert.Contains(t, err.Error(), "not support manually specifying the partition names if partition key mode is used")
	})
}

func TestUpdateTask_PreExecute_InvalidNumRows(t *testing.T) {
	mockey.PatchConvey("TestUpdateTask_PreExecute_InvalidNumRows", t, func() {
		globalMetaCache = &MetaCache{}

		mockey.Mock(GetReplicateID).Return("", nil).Build()
		mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build()
		mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{
			updateTimestamp: 12345,
		}, nil).Build()
		mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build()
		mockey.Mock(isPartitionKeyMode).Return(false, nil).Build()
		mockey.Mock((*MetaCache).GetPartitionInfo).Return(&partitionInfo{
			name: "_default",
		}, nil).Build()

		task := createTestUpdateTask()
		task.req.FieldsData = []*schemapb.FieldData{}
		task.req.NumRows = 0 // Invalid num_rows

		err := task.PreExecute(context.Background())

		assert.Error(t, err)
		assert.Contains(t, err.Error(), "invalid num_rows")
	})
}

func TestUpdateTask_PreExecute_QueryPreExecuteError(t *testing.T) {
	mockey.PatchConvey("TestUpdateTask_PreExecute_QueryPreExecuteError", t, func() {
		globalMetaCache = &MetaCache{}

		mockey.Mock(GetReplicateID).Return("", nil).Build()
		mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build()
		mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{
			updateTimestamp: 12345,
		}, nil).Build()
		mockey.Mock((*MetaCache).GetCollectionSchema).Return(createTestSchema(), nil).Build()
		mockey.Mock(isPartitionKeyMode).Return(false, nil).Build()
		mockey.Mock((*MetaCache).GetPartitionInfo).Return(&partitionInfo{
			name: "_default",
		}, nil).Build()

		expectedErr := merr.WrapErrParameterInvalidMsg("query pre-execute failed")
		mockey.Mock((*upsertTask).queryPreExecute).Return(expectedErr).Build()

		task := createTestUpdateTask()
		task.req.PartialUpdate = true

		err := task.PreExecute(context.Background())

		assert.Error(t, err)
		assert.Contains(t, err.Error(), "query pre-execute failed")
	})
}

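// The three queryPreExecute tests below cover the partial-update merge matrix:
// rows whose primary key already exists are deleted and re-inserted with the
// merged values, while unseen primary keys become plain inserts. retrieveByPKs
// is patched so only the merge logic itself is exercised.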
func TestUpsertTask_queryPreExecute_MixLogic(t *testing.T) {
	// Schema for the test collection
	schema := newSchemaInfo(&schemapb.CollectionSchema{
		Name: "test_merge_collection",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "id", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
			{FieldID: 101, Name: "value", DataType: schemapb.DataType_Int32},
			{FieldID: 102, Name: "extra", DataType: schemapb.DataType_VarChar, Nullable: true},
		},
	})

	// Upsert IDs: 1 (update), 2 (update), 3 (insert)
	upsertData := []*schemapb.FieldData{
		{
			FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2, 3}}}}},
		},
		{
			FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{100, 200, 300}}}}},
		},
	}
	numRows := uint64(len(upsertData[0].GetScalars().GetLongData().GetData()))

	// Query result for existing PKs: 1, 2
	mockQueryResult := &milvuspb.QueryResults{
		Status: merr.Success(),
		FieldsData: []*schemapb.FieldData{
			{
				FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
				Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2}}}}},
			},
			{
				FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
				Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{10, 20}}}}},
			},
			{
				FieldName: "extra", FieldId: 102, Type: schemapb.DataType_VarChar,
				Field:     &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: []string{"old1", "old2"}}}}},
				ValidData: []bool{true, true},
			},
		},
	}

	task := &upsertTask{
		ctx:    context.Background(),
		schema: schema,
		req: &milvuspb.UpsertRequest{
			FieldsData: upsertData,
			NumRows:    uint32(numRows),
		},
		upsertMsg: &msgstream.UpsertMsg{
			InsertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{
					FieldsData: upsertData,
					NumRows:    numRows,
				},
			},
		},
		node: &Proxy{},
	}

	mockRetrieve := mockey.Mock(retrieveByPKs).Return(mockQueryResult, segcore.StorageCost{}, nil).Build()
	defer mockRetrieve.UnPatch()

	err := task.queryPreExecute(context.Background())
	assert.NoError(t, err)

	// Verify delete PKs
	deletePks := task.deletePKs.GetIntId().GetData()
	assert.ElementsMatch(t, []int64{1, 2}, deletePks)

	// Verify merged insert data
	primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema.CollectionSchema)
	assert.NoError(t, err)
	idField, err := typeutil.GetPrimaryFieldData(task.insertFieldData, primaryFieldSchema)
	assert.NoError(t, err)
	ids, err := parsePrimaryFieldData2IDs(idField)
	assert.NoError(t, err)
	insertPKs := ids.GetIntId().GetData()
	assert.Equal(t, []int64{1, 2, 3}, insertPKs)

	var valueField *schemapb.FieldData
	for _, f := range task.insertFieldData {
		if f.GetFieldName() == "value" {
			valueField = f
			break
		}
	}
	assert.NotNil(t, valueField)
	assert.Equal(t, []int32{100, 200, 300}, valueField.GetScalars().GetIntData().GetData())
}

func TestUpsertTask_queryPreExecute_PureInsert(t *testing.T) {
	// Schema for the test collection
	schema := newSchemaInfo(&schemapb.CollectionSchema{
		Name: "test_merge_collection",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "id", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
			{FieldID: 101, Name: "value", DataType: schemapb.DataType_Int32},
			{FieldID: 102, Name: "extra", DataType: schemapb.DataType_VarChar, Nullable: true},
		},
	})

	// Upsert IDs: 4, 5
	upsertData := []*schemapb.FieldData{
		{
			FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{4, 5}}}}},
		},
		{
			FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{400, 500}}}}},
		},
	}
	numRows := uint64(len(upsertData[0].GetScalars().GetLongData().GetData()))

	// Query result is empty, but schema is preserved
	mockQueryResult := &milvuspb.QueryResults{Status: merr.Success(), FieldsData: []*schemapb.FieldData{
		{
			FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{}}}}},
		},
		{
			FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{}}}}},
		},
		{
			FieldName: "extra", FieldId: 102, Type: schemapb.DataType_VarChar,
			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: []string{}}}}},
		},
	}}

	task := &upsertTask{
		ctx:    context.Background(),
		schema: schema,
		req: &milvuspb.UpsertRequest{
			FieldsData: upsertData,
			NumRows:    uint32(numRows),
		},
		upsertMsg: &msgstream.UpsertMsg{
			InsertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{
					FieldsData: upsertData,
					NumRows:    numRows,
				},
			},
		},
		node: &Proxy{},
	}

	mockRetrieve := mockey.Mock(retrieveByPKs).Return(mockQueryResult, segcore.StorageCost{}, nil).Build()
	defer mockRetrieve.UnPatch()

	err := task.queryPreExecute(context.Background())
	assert.NoError(t, err)

	// Verify delete PKs
	deletePks := task.deletePKs.GetIntId().GetData()
	assert.Empty(t, deletePks)

	// Verify merged insert data
	primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema.CollectionSchema)
	assert.NoError(t, err)
	idField, err := typeutil.GetPrimaryFieldData(task.insertFieldData, primaryFieldSchema)
	assert.NoError(t, err)
	ids, err := parsePrimaryFieldData2IDs(idField)
	assert.NoError(t, err)
	insertPKs := ids.GetIntId().GetData()
	assert.Equal(t, []int64{4, 5}, insertPKs)

	var valueField *schemapb.FieldData
	for _, f := range task.insertFieldData {
		if f.GetFieldName() == "value" {
			valueField = f
			break
		}
	}
	assert.NotNil(t, valueField)
	assert.Equal(t, []int32{400, 500}, valueField.GetScalars().GetIntData().GetData())
}

func TestUpsertTask_queryPreExecute_PureUpdate(t *testing.T) {
	// Schema for the test collection
	schema := newSchemaInfo(&schemapb.CollectionSchema{
		Name: "test_merge_collection",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "id", IsPrimaryKey: true, DataType: schemapb.DataType_Int64},
			{FieldID: 101, Name: "value", DataType: schemapb.DataType_Int32},
			{FieldID: 102, Name: "extra", DataType: schemapb.DataType_VarChar, Nullable: true},
		},
	})

	// Upsert IDs: 6, 7
	upsertData := []*schemapb.FieldData{
		{
			FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{6, 7}}}}},
		},
		{
			FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
			Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{600, 700}}}}},
		},
	}
	numRows := uint64(len(upsertData[0].GetScalars().GetLongData().GetData()))

	// Query result for existing PKs: 6, 7
	mockQueryResult := &milvuspb.QueryResults{
		Status: merr.Success(),
		FieldsData: []*schemapb.FieldData{
			{
				FieldName: "id", FieldId: 100, Type: schemapb.DataType_Int64,
				Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{6, 7}}}}},
			},
			{
				FieldName: "value", FieldId: 101, Type: schemapb.DataType_Int32,
				Field: &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_IntData{IntData: &schemapb.IntArray{Data: []int32{60, 70}}}}},
			},
		},
	}

	task := &upsertTask{
		ctx:    context.Background(),
		schema: schema,
		req: &milvuspb.UpsertRequest{
			FieldsData: upsertData,
			NumRows:    uint32(numRows),
		},
		upsertMsg: &msgstream.UpsertMsg{
			InsertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{
					FieldsData: upsertData,
					NumRows:    numRows,
				},
			},
		},
		node: &Proxy{},
	}

	mockRetrieve := mockey.Mock(retrieveByPKs).Return(mockQueryResult, segcore.StorageCost{}, nil).Build()
	defer mockRetrieve.UnPatch()

	err := task.queryPreExecute(context.Background())
	assert.NoError(t, err)

	// Verify delete PKs
	deletePks := task.deletePKs.GetIntId().GetData()
	assert.ElementsMatch(t, []int64{6, 7}, deletePks)

	// Verify merged insert data
	primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema.CollectionSchema)
	assert.NoError(t, err)
	idField, err := typeutil.GetPrimaryFieldData(task.insertFieldData, primaryFieldSchema)
	assert.NoError(t, err)
	ids, err := parsePrimaryFieldData2IDs(idField)
	assert.NoError(t, err)
	insertPKs := ids.GetIntId().GetData()
	assert.Equal(t, []int64{6, 7}, insertPKs)

	var valueField *schemapb.FieldData
	for _, f := range task.insertFieldData {
		if f.GetFieldName() == "value" {
			valueField = f
			break
		}
	}
	assert.NotNil(t, valueField)
	assert.Equal(t, []int32{600, 700}, valueField.GetScalars().GetIntData().GetData())
}

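// ToCompressedFormatNullable is expected to drop payload entries whose
// ValidData flag is false while leaving ValidData itself at the full row
// length; the assertions below pin that contract for Timestamptz and Geometry
// columns.
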
// Test ToCompressedFormatNullable for Geometry and Timestamptz types
func TestToCompressedFormatNullable_GeometryAndTimestamptz(t *testing.T) {
	t.Run("timestamptz with null values", func(t *testing.T) {
		field := &schemapb.FieldData{
			Type:      schemapb.DataType_Timestamptz,
			FieldName: "timestamp_field",
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_TimestamptzData{
						TimestamptzData: &schemapb.TimestamptzArray{
							Data: []int64{1000, 0, 3000, 0},
						},
					},
				},
			},
			ValidData: []bool{true, false, true, false},
		}

		err := ToCompressedFormatNullable(field)
		assert.NoError(t, err)
		assert.Equal(t, []int64{1000, 3000}, field.GetScalars().GetTimestamptzData().GetData())
		assert.Equal(t, []bool{true, false, true, false}, field.ValidData)
	})

	t.Run("geometry WKT with null values", func(t *testing.T) {
		field := &schemapb.FieldData{
			Type:      schemapb.DataType_Geometry,
			FieldName: "geometry_field",
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_GeometryWktData{
						GeometryWktData: &schemapb.GeometryWktArray{
							Data: []string{"POINT (1 2)", "", "POINT (5 6)"},
						},
					},
				},
			},
			ValidData: []bool{true, false, true},
		}

		err := ToCompressedFormatNullable(field)
		assert.NoError(t, err)
		assert.Equal(t, []string{"POINT (1 2)", "POINT (5 6)"}, field.GetScalars().GetGeometryWktData().GetData())
	})

	t.Run("geometry WKB with null values", func(t *testing.T) {
		field := &schemapb.FieldData{
			Type:      schemapb.DataType_Geometry,
			FieldName: "geometry_field",
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_GeometryData{
						GeometryData: &schemapb.GeometryArray{
							Data: [][]byte{{0x01, 0x02}, nil, {0x05, 0x06}},
						},
					},
				},
			},
			ValidData: []bool{true, false, true},
		}

		err := ToCompressedFormatNullable(field)
		assert.NoError(t, err)
		assert.Equal(t, [][]byte{{0x01, 0x02}, {0x05, 0x06}}, field.GetScalars().GetGeometryData().GetData())
	})
}

// Test GenNullableFieldData for Geometry and Timestamptz types
func TestGenNullableFieldData_GeometryAndTimestamptz(t *testing.T) {
	t.Run("generate timestamptz nullable field", func(t *testing.T) {
		field := &schemapb.FieldSchema{
			FieldID:   100,
			Name:      "timestamp_field",
			DataType:  schemapb.DataType_Timestamptz,
			IsDynamic: false,
		}

		upsertIDSize := 5
		fieldData, err := GenNullableFieldData(field, upsertIDSize)

		assert.NoError(t, err)
		assert.NotNil(t, fieldData)
		assert.Equal(t, int64(100), fieldData.FieldId)
		assert.Equal(t, "timestamp_field", fieldData.FieldName)
		assert.Len(t, fieldData.ValidData, upsertIDSize)
		assert.Len(t, fieldData.GetScalars().GetTimestamptzData().GetData(), upsertIDSize)
	})

	t.Run("generate geometry nullable field", func(t *testing.T) {
		field := &schemapb.FieldSchema{
			FieldID:   101,
			Name:      "geometry_field",
			DataType:  schemapb.DataType_Geometry,
			IsDynamic: false,
		}

		upsertIDSize := 3
		fieldData, err := GenNullableFieldData(field, upsertIDSize)

		assert.NoError(t, err)
		assert.NotNil(t, fieldData)
		assert.Equal(t, int64(101), fieldData.FieldId)
		assert.Equal(t, "geometry_field", fieldData.FieldName)
		assert.Len(t, fieldData.ValidData, upsertIDSize)
		assert.Len(t, fieldData.GetScalars().GetGeometryWktData().GetData(), upsertIDSize)
	})
}

func TestUpsertTask_PlanNamespace_AfterPreExecute(t *testing.T) {
	mockey.PatchConvey("TestUpsertTask_PlanNamespace_AfterPreExecute", t, func() {
		// Setup global meta cache and common mocks
		globalMetaCache = &MetaCache{}
		mockey.Mock(GetReplicateID).Return("", nil).Build()
		mockey.Mock((*MetaCache).GetCollectionID).Return(int64(1001), nil).Build()
		mockey.Mock((*MetaCache).GetCollectionInfo).Return(&collectionInfo{updateTimestamp: 12345}, nil).Build()
		mockey.Mock((*MetaCache).GetPartitionInfo).Return(&partitionInfo{name: "_default"}, nil).Build()
		mockey.Mock((*MetaCache).GetPartitionID).Return(int64(1002), nil).Build()
		mockey.Mock(isPartitionKeyMode).Return(false, nil).Build()
		mockey.Mock(validatePartitionTag).Return(nil).Build()

		// Schema with namespace enabled
		mockey.Mock((*MetaCache).GetCollectionSchema).To(func(_ *MetaCache, _ context.Context, _ string, _ string) (*schemaInfo, error) {
			info := createTestSchema()
			info.CollectionSchema.Properties = append(info.CollectionSchema.Properties, &commonpb.KeyValuePair{Key: common.NamespaceEnabledKey, Value: "true"})
			return info, nil
		}).Build()

		// Capture plan to verify namespace
		var capturedPlan *planpb.PlanNode
		mockey.Mock(planparserv2.CreateRequeryPlan).To(func(_ *schemapb.FieldSchema, _ *schemapb.IDs) *planpb.PlanNode {
			capturedPlan = &planpb.PlanNode{}
			return capturedPlan
		}).Build()

		// Mock query to return a valid result for queryPreExecute merge path
		mockey.Mock((*Proxy).query).Return(&milvuspb.QueryResults{
			Status: merr.Success(),
			FieldsData: []*schemapb.FieldData{
				{
					FieldName: "id",
					FieldId:   100,
					Type:      schemapb.DataType_Int64,
					Field:     &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_LongData{LongData: &schemapb.LongArray{Data: []int64{1, 2}}}}},
				},
				{
					FieldName: "name",
					FieldId:   102,
					Type:      schemapb.DataType_VarChar,
					Field:     &schemapb.FieldData_Scalars{Scalars: &schemapb.ScalarField{Data: &schemapb.ScalarField_StringData{StringData: &schemapb.StringArray{Data: []string{"old1", "old2"}}}}},
				},
				{
					FieldName: "vector",
					FieldId:   101,
					Type:      schemapb.DataType_FloatVector,
					Field:     &schemapb.FieldData_Vectors{Vectors: &schemapb.VectorField{Dim: 128, Data: &schemapb.VectorField_FloatVector{FloatVector: &schemapb.FloatArray{Data: make([]float32, 256)}}}},
				},
			},
		}, segcore.StorageCost{}, nil).Build()

		// Build task
		task := createTestUpdateTask()
		ns := "ns-1"
		task.req.PartialUpdate = true
		task.req.Namespace = &ns

		// Skip insert/delete heavy logic
		mockey.Mock((*upsertTask).insertPreExecute).Return(nil).Build()
		mockey.Mock((*upsertTask).deletePreExecute).Return(nil).Build()

		err := task.PreExecute(context.Background())
		assert.NoError(t, err)
		assert.NotNil(t, capturedPlan)
		assert.NotNil(t, capturedPlan.Namespace)
		assert.Equal(t, *task.req.Namespace, *capturedPlan.Namespace)
	})
}

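// The deduplication tests below pin the "last write wins" contract of
// DeduplicateFieldData (introduced for issue #44320): when a single upsert
// batch repeats a primary key, only the final occurrence of each key, together
// with its row in every other column, survives.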
func TestUpsertTask_Deduplicate_Int64PK(t *testing.T) {
|
|
// Test deduplication with Int64 primary key
|
|
primaryFieldSchema := &schemapb.FieldSchema{
|
|
Name: "id",
|
|
FieldID: 100,
|
|
DataType: schemapb.DataType_Int64,
|
|
IsPrimaryKey: true,
|
|
}
|
|
|
|
collSchema := &schemapb.CollectionSchema{
|
|
Fields: []*schemapb.FieldSchema{
|
|
primaryFieldSchema,
|
|
{
|
|
Name: "float_field",
|
|
FieldID: 101,
|
|
DataType: schemapb.DataType_Float,
|
|
},
|
|
},
|
|
}
|
|
schema := newSchemaInfo(collSchema)
|
|
|
|
// Create field data with duplicate IDs: [1, 2, 3, 2, 1]
|
|
// Expected to keep last occurrence of each: [3, 2, 1] (indices 2, 3, 4)
|
|
fieldsData := []*schemapb.FieldData{
|
|
{
|
|
FieldName: "id",
|
|
Type: schemapb.DataType_Int64,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_LongData{
|
|
LongData: &schemapb.LongArray{
|
|
Data: []int64{1, 2, 3, 2, 1},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
{
|
|
FieldName: "float_field",
|
|
Type: schemapb.DataType_Float,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_FloatData{
|
|
FloatData: &schemapb.FloatArray{
|
|
Data: []float32{1.1, 2.2, 3.3, 2.4, 1.5},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
deduplicatedFields, newNumRows, err := DeduplicateFieldData(primaryFieldSchema, fieldsData, schema)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, uint32(3), newNumRows)
|
|
assert.Equal(t, 2, len(deduplicatedFields))
|
|
|
|
// Check deduplicated primary keys
|
|
pkField := deduplicatedFields[0]
|
|
pkData := pkField.GetScalars().GetLongData().GetData()
|
|
assert.Equal(t, 3, len(pkData))
|
|
assert.Equal(t, []int64{3, 2, 1}, pkData)
|
|
|
|
// Check corresponding float values (should be 3.3, 2.4, 1.5)
|
|
floatField := deduplicatedFields[1]
|
|
floatData := floatField.GetScalars().GetFloatData().GetData()
|
|
assert.Equal(t, 3, len(floatData))
|
|
assert.Equal(t, []float32{3.3, 2.4, 1.5}, floatData)
|
|
}
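
// A minimal standalone sketch of the keep-last row selection that the cases
// above and below assert. keepLastIndices is a hypothetical helper defined
// here for illustration only; it is not part of the proxy package API, and
// DeduplicateFieldData is not required to be implemented this way.
func keepLastIndices[K comparable](pks []K) []int {
	// Record the last index at which each primary key occurs.
	last := make(map[K]int, len(pks))
	for i, pk := range pks {
		last[pk] = i // a later occurrence overwrites an earlier one
	}
	// Keep a row only if it is the last occurrence of its key, preserving order.
	keep := make([]int, 0, len(last))
	for i, pk := range pks {
		if last[pk] == i {
			keep = append(keep, i)
		}
	}
	return keep
}

func TestKeepLastIndices_Int64Sketch(t *testing.T) {
	// [1, 2, 3, 2, 1] keeps indices 2, 3, 4, matching the Int64 case above.
	assert.Equal(t, []int{2, 3, 4}, keepLastIndices([]int64{1, 2, 3, 2, 1}))
}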

func TestUpsertTask_Deduplicate_VarCharPK(t *testing.T) {
	// Test deduplication with VarChar primary key
	primaryFieldSchema := &schemapb.FieldSchema{
		Name:         "id",
		FieldID:      100,
		DataType:     schemapb.DataType_VarChar,
		IsPrimaryKey: true,
	}

	collSchema := &schemapb.CollectionSchema{
		Fields: []*schemapb.FieldSchema{
			primaryFieldSchema,
			{
				Name:     "int_field",
				FieldID:  101,
				DataType: schemapb.DataType_Int64,
			},
		},
	}
	schema := newSchemaInfo(collSchema)

	// Create field data with duplicate IDs: ["a", "b", "c", "b", "a"]
	// Expected to keep last occurrence of each: ["c", "b", "a"] (indices 2, 3, 4)
	fieldsData := []*schemapb.FieldData{
		{
			FieldName: "id",
			Type:      schemapb.DataType_VarChar,
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_StringData{
						StringData: &schemapb.StringArray{
							Data: []string{"a", "b", "c", "b", "a"},
						},
					},
				},
			},
		},
		{
			FieldName: "int_field",
			Type:      schemapb.DataType_Int64,
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_LongData{
						LongData: &schemapb.LongArray{
							Data: []int64{100, 200, 300, 201, 101},
						},
					},
				},
			},
		},
	}

	deduplicatedFields, newNumRows, err := DeduplicateFieldData(primaryFieldSchema, fieldsData, schema)
	assert.NoError(t, err)
	assert.Equal(t, uint32(3), newNumRows)
	assert.Equal(t, 2, len(deduplicatedFields))

	// Check deduplicated primary keys
	pkField := deduplicatedFields[0]
	pkData := pkField.GetScalars().GetStringData().GetData()
	assert.Equal(t, 3, len(pkData))
	assert.Equal(t, []string{"c", "b", "a"}, pkData)

	// Check corresponding int64 values (should be 300, 201, 101)
	int64Field := deduplicatedFields[1]
	int64Data := int64Field.GetScalars().GetLongData().GetData()
	assert.Equal(t, 3, len(int64Data))
	assert.Equal(t, []int64{300, 201, 101}, int64Data)
}
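
// The same hypothetical keepLastIndices sketch covers VarChar keys, since
// the keep-last selection only requires the primary key to be comparable.
func TestKeepLastIndices_VarCharSketch(t *testing.T) {
	// ["a", "b", "c", "b", "a"] keeps indices 2, 3, 4, matching the case above.
	assert.Equal(t, []int{2, 3, 4}, keepLastIndices([]string{"a", "b", "c", "b", "a"}))
}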

func TestUpsertTask_Deduplicate_NoDuplicates(t *testing.T) {
	// Test with no duplicates - should return original data
	primaryFieldSchema := &schemapb.FieldSchema{
		Name:         "id",
		FieldID:      100,
		DataType:     schemapb.DataType_Int64,
		IsPrimaryKey: true,
	}

	collSchema := &schemapb.CollectionSchema{
		Fields: []*schemapb.FieldSchema{
			primaryFieldSchema,
		},
	}
	schema := newSchemaInfo(collSchema)

	fieldsData := []*schemapb.FieldData{
		{
			FieldName: "id",
			Type:      schemapb.DataType_Int64,
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_LongData{
						LongData: &schemapb.LongArray{
							Data: []int64{1, 2, 3, 4, 5},
						},
					},
				},
			},
		},
	}

	deduplicatedFields, newNumRows, err := DeduplicateFieldData(primaryFieldSchema, fieldsData, schema)
	assert.NoError(t, err)
	assert.Equal(t, uint32(5), newNumRows)
	assert.Equal(t, 1, len(deduplicatedFields))

	// Should be unchanged
	pkField := deduplicatedFields[0]
	pkData := pkField.GetScalars().GetLongData().GetData()
	assert.Equal(t, []int64{1, 2, 3, 4, 5}, pkData)
}

func TestUpsertTask_Deduplicate_WithVector(t *testing.T) {
	// Test deduplication with vector field
	primaryFieldSchema := &schemapb.FieldSchema{
		Name:         "id",
		FieldID:      100,
		DataType:     schemapb.DataType_Int64,
		IsPrimaryKey: true,
	}

	collSchema := &schemapb.CollectionSchema{
		Fields: []*schemapb.FieldSchema{
			primaryFieldSchema,
			{
				Name:     "vector",
				FieldID:  101,
				DataType: schemapb.DataType_FloatVector,
			},
		},
	}
	schema := newSchemaInfo(collSchema)

	dim := 4
	// Create field data with duplicate IDs: [1, 2, 1]
	// Expected to keep indices [1, 2] (last occurrence of 2, last occurrence of 1)
	fieldsData := []*schemapb.FieldData{
		{
			FieldName: "id",
			Type:      schemapb.DataType_Int64,
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_LongData{
						LongData: &schemapb.LongArray{
							Data: []int64{1, 2, 1},
						},
					},
				},
			},
		},
		{
			FieldName: "vector",
			Type:      schemapb.DataType_FloatVector,
			Field: &schemapb.FieldData_Vectors{
				Vectors: &schemapb.VectorField{
					Dim: int64(dim),
					Data: &schemapb.VectorField_FloatVector{
						FloatVector: &schemapb.FloatArray{
							Data: []float32{
								1.0, 1.1, 1.2, 1.3, // vector for ID 1 (first occurrence)
								2.0, 2.1, 2.2, 2.3, // vector for ID 2
								1.4, 1.5, 1.6, 1.7, // vector for ID 1 (second occurrence - keep this)
							},
						},
					},
				},
			},
		},
	}

	deduplicatedFields, newNumRows, err := DeduplicateFieldData(primaryFieldSchema, fieldsData, schema)
	assert.NoError(t, err)
	assert.Equal(t, uint32(2), newNumRows)
	assert.Equal(t, 2, len(deduplicatedFields))

	// Check deduplicated primary keys
	pkField := deduplicatedFields[0]
	pkData := pkField.GetScalars().GetLongData().GetData()
	assert.Equal(t, 2, len(pkData))
	assert.Equal(t, []int64{2, 1}, pkData)

	// Check corresponding vectors (should keep the vector for ID 2 and the last vector for ID 1)
	vectorField := deduplicatedFields[1]
	vectorData := vectorField.GetVectors().GetFloatVector().GetData()
	assert.Equal(t, 8, len(vectorData)) // 2 vectors * 4 dimensions
	expectedVector := []float32{
		2.0, 2.1, 2.2, 2.3, // vector for ID 2
		1.4, 1.5, 1.6, 1.7, // vector for ID 1 (last occurrence)
	}
	assert.Equal(t, expectedVector, vectorData)
}
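
// A minimal sketch of how a kept row index maps onto the flattened
// float-vector data, matching the expectations of the vector case above.
// gatherVectorRows is a hypothetical illustration helper, not the proxy
// package's implementation.
func gatherVectorRows(flat []float32, dim int, keep []int) []float32 {
	out := make([]float32, 0, len(keep)*dim)
	for _, row := range keep {
		// Row i occupies flat[i*dim : (i+1)*dim] in the flattened layout.
		out = append(out, flat[row*dim:(row+1)*dim]...)
	}
	return out
}

func TestGatherVectorRows_Sketch(t *testing.T) {
	flat := []float32{
		1.0, 1.1, 1.2, 1.3,
		2.0, 2.1, 2.2, 2.3,
		1.4, 1.5, 1.6, 1.7,
	}
	// Keeping rows 1 and 2 reproduces the expected vectors from the test above.
	assert.Equal(t,
		[]float32{2.0, 2.1, 2.2, 2.3, 1.4, 1.5, 1.6, 1.7},
		gatherVectorRows(flat, 4, []int{1, 2}))
}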

func TestUpsertTask_Deduplicate_EmptyData(t *testing.T) {
	// Test with empty data
	primaryFieldSchema := &schemapb.FieldSchema{
		Name:         "id",
		FieldID:      100,
		DataType:     schemapb.DataType_Int64,
		IsPrimaryKey: true,
	}

	collSchema := &schemapb.CollectionSchema{
		Fields: []*schemapb.FieldSchema{
			primaryFieldSchema,
		},
	}
	schema := newSchemaInfo(collSchema)

	fieldsData := []*schemapb.FieldData{}

	deduplicatedFields, newNumRows, err := DeduplicateFieldData(primaryFieldSchema, fieldsData, schema)
	assert.NoError(t, err)
	assert.Equal(t, uint32(0), newNumRows)
	assert.Equal(t, 0, len(deduplicatedFields))
}

func TestUpsertTask_Deduplicate_MissingPrimaryKey(t *testing.T) {
	// Test with the primary key field missing from the field data
	primaryFieldSchema := &schemapb.FieldSchema{
		Name:         "id",
		FieldID:      100,
		DataType:     schemapb.DataType_Int64,
		IsPrimaryKey: true,
	}

	collSchema := &schemapb.CollectionSchema{
		Fields: []*schemapb.FieldSchema{
			primaryFieldSchema,
			{
				Name:     "other_field",
				FieldID:  101,
				DataType: schemapb.DataType_Float,
			},
		},
	}
	schema := newSchemaInfo(collSchema)

	fieldsData := []*schemapb.FieldData{
		{
			FieldName: "other_field",
			Type:      schemapb.DataType_Float,
			Field: &schemapb.FieldData_Scalars{
				Scalars: &schemapb.ScalarField{
					Data: &schemapb.ScalarField_FloatData{
						FloatData: &schemapb.FloatArray{
							Data: []float32{1.1, 2.2},
						},
					},
				},
			},
		},
	}

	_, _, err := DeduplicateFieldData(primaryFieldSchema, fieldsData, schema)
	// The call must fail: either validateFieldDataColumns rejects the column
	// count mismatch first, or the lookup of the missing primary key fails.
	assert.Error(t, err)
}