milvus/internal/proxy/task_insert_test.go
cai.zhang 19346fa389
feat: Geospatial Data Type and GIS Function support for milvus (#44547)
issue: #43427

The main goal of this PR is to merge #37417 into Milvus 2.5 without conflicts.

# Main Goals

1. Create and describe collections with geospatial type
2. Insert geospatial data into the insert binlog
3. Load segments containing geospatial data into memory
4. Enable query and search to display geospatial data
5. Support GIS functions such as ST_EQUALS in queries
6. Support R-Tree index for geometry type

# Solution

1. **Add Type**: Modify the Milvus core by adding a Geospatial type in
both the C++ and Go code layers, defining the Geospatial data structure
and the corresponding interfaces.
2. **Dependency Libraries**: Introduce necessary geospatial data
processing libraries. In the C++ source code, use Conan package
management to include the GDAL library. In the Go source code, add the
go-geom library to the go.mod file.
3. **Protocol Interface**: Revise the Milvus protocol to provide
mechanisms for Geospatial message serialization and deserialization.
4. **Data Pipeline**: Facilitate interaction between the client and
proxy using the WKT format for geospatial data. The proxy will convert
all data into WKB format for downstream processing, providing column
data interfaces, segment encapsulation, segment loading, payload
writing, and cache block management.
5. **Query Operators**: Implement simple display and support for filter
queries. Initially, focus on filtering based on spatial relationships
for a single column of geospatial literal values, providing parsing and
execution for query expressions. Currently, only brute-force search is supported.
6. **Client Modification**: Enable the client to handle user input for
geospatial data and facilitate end-to-end testing. See the corresponding
changes in pymilvus.

---------

Signed-off-by: Yinwei Li <yinwei.li@zilliz.com>
Signed-off-by: Cai Zhang <cai.zhang@zilliz.com>
Co-authored-by: ZhuXi <150327960+Yinwei-Yu@users.noreply.github.com>
2025-09-28 19:43:05 +08:00

729 lines
24 KiB
Go

package proxy
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/util/function/embedding"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/rootcoordpb"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/testutils"
)
// TestInsertTask_CheckAligned verifies that InsertMsg.CheckAligned accepts a
// column-based insert request only when every field column carries exactly
// NumRows rows. For each column it checks that too few and too many rows are
// both rejected, and that restoring the aligned column makes the message
// valid again.
func TestInsertTask_CheckAligned(t *testing.T) {
	var err error
	// passed NumRows is 0 and no field data is provided
	case1 := insertTask{
		insertMsg: &BaseInsertTask{
			InsertRequest: &msgpb.InsertRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_Insert,
				},
				NumRows: 0,
			},
		},
	}
	err = case1.insertMsg.CheckAligned()
	assert.NoError(t, err)
	// checkFieldsDataBySchema was already checked by TestInsertTask_checkFieldsDataBySchema
	boolFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Bool}
	int8FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int8}
	int16FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int16}
	int32FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int32}
	int64FieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Int64}
	floatFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Float}
	doubleFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Double}
	floatVectorFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector}
	binaryVectorFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_BinaryVector}
	float16VectorFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Float16Vector}
	bfloat16VectorFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_BFloat16Vector}
	varCharFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_VarChar}
	geometryFieldSchema := &schemapb.FieldSchema{DataType: schemapb.DataType_Geometry}
	numRows := 20
	dim := 128
	case2 := insertTask{
		insertMsg: &BaseInsertTask{
			InsertRequest: &msgpb.InsertRequest{
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_Insert,
				},
				Version:    msgpb.InsertDataVersion_ColumnBased,
				RowIDs:     testutils.GenerateInt64Array(numRows),
				Timestamps: testutils.GenerateUint64Array(numRows),
			},
		},
		schema: &schemapb.CollectionSchema{
			Name:        "TestInsertTask_checkRowNums",
			Description: "TestInsertTask_checkRowNums",
			AutoID:      false,
			Fields: []*schemapb.FieldSchema{
				boolFieldSchema,
				int8FieldSchema,
				int16FieldSchema,
				int32FieldSchema,
				int64FieldSchema,
				floatFieldSchema,
				doubleFieldSchema,
				floatVectorFieldSchema,
				binaryVectorFieldSchema,
				float16VectorFieldSchema,
				bfloat16VectorFieldSchema,
				varCharFieldSchema,
				geometryFieldSchema,
			},
		},
	}

	// One column generator per schema field, in schema order. FieldsData is
	// built from this table, so index i in the loop below always refers to the
	// column named by columns[i].name.
	columns := []struct {
		name string
		gen  func(rows int) *schemapb.FieldData
	}{
		{"Bool", func(rows int) *schemapb.FieldData { return newScalarFieldData(boolFieldSchema, "Bool", rows) }},
		{"Int8", func(rows int) *schemapb.FieldData { return newScalarFieldData(int8FieldSchema, "Int8", rows) }},
		{"Int16", func(rows int) *schemapb.FieldData { return newScalarFieldData(int16FieldSchema, "Int16", rows) }},
		{"Int32", func(rows int) *schemapb.FieldData { return newScalarFieldData(int32FieldSchema, "Int32", rows) }},
		{"Int64", func(rows int) *schemapb.FieldData { return newScalarFieldData(int64FieldSchema, "Int64", rows) }},
		{"Float", func(rows int) *schemapb.FieldData { return newScalarFieldData(floatFieldSchema, "Float", rows) }},
		{"Double", func(rows int) *schemapb.FieldData { return newScalarFieldData(doubleFieldSchema, "Double", rows) }},
		{"FloatVector", func(rows int) *schemapb.FieldData { return newFloatVectorFieldData("FloatVector", rows, dim) }},
		{"BinaryVector", func(rows int) *schemapb.FieldData { return newBinaryVectorFieldData("BinaryVector", rows, dim) }},
		{"Float16Vector", func(rows int) *schemapb.FieldData { return newFloat16VectorFieldData("Float16Vector", rows, dim) }},
		{"BFloat16Vector", func(rows int) *schemapb.FieldData { return newBFloat16VectorFieldData("BFloat16Vector", rows, dim) }},
		{"VarChar", func(rows int) *schemapb.FieldData { return newScalarFieldData(varCharFieldSchema, "VarChar", rows) }},
		{"Geometry", func(rows int) *schemapb.FieldData { return newScalarFieldData(geometryFieldSchema, "Geometry", rows) }},
	}

	// satisfied
	case2.insertMsg.NumRows = uint64(numRows)
	case2.insertMsg.FieldsData = make([]*schemapb.FieldData, 0, len(columns))
	for _, col := range columns {
		case2.insertMsg.FieldsData = append(case2.insertMsg.FieldsData, col.gen(numRows))
	}
	err = case2.insertMsg.CheckAligned()
	assert.NoError(t, err)

	// For every column: fewer rows and more rows than NumRows must both be
	// rejected; restoring the aligned column must make the message valid again,
	// so the next column starts from a fully aligned message.
	for i, col := range columns {
		case2.insertMsg.FieldsData[i] = col.gen(numRows / 2)
		err = case2.insertMsg.CheckAligned()
		assert.Error(t, err, "less %s data should be rejected", col.name)

		case2.insertMsg.FieldsData[i] = col.gen(numRows * 2)
		err = case2.insertMsg.CheckAligned()
		assert.Error(t, err, "more %s data should be rejected", col.name)

		case2.insertMsg.FieldsData[i] = col.gen(numRows)
		err = case2.insertMsg.CheckAligned()
		assert.NoError(t, err, "restored %s data should be aligned", col.name)
	}
}
func TestInsertTask(t *testing.T) {
t.Run("test getChannels", func(t *testing.T) {
collectionID := UniqueID(0)
collectionName := "col-0"
channels := []pChan{"mock-chan-0", "mock-chan-1"}
cache := NewMockCache(t)
cache.On("GetCollectionID",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(collectionID, nil)
globalMetaCache = cache
chMgr := NewMockChannelsMgr(t)
chMgr.EXPECT().getChannels(mock.Anything).Return(channels, nil)
it := insertTask{
ctx: context.Background(),
insertMsg: &msgstream.InsertMsg{
InsertRequest: &msgpb.InsertRequest{
CollectionName: collectionName,
},
},
chMgr: chMgr,
}
err := it.setChannels()
assert.NoError(t, err)
resChannels := it.getChannels()
assert.ElementsMatch(t, channels, resChannels)
assert.ElementsMatch(t, channels, it.pChannels)
})
}
// TestMaxInsertSize checks that PreExecute fails with ErrParameterTooLarge
// once QuotaConfig.MaxInsertSize is lowered below the size of the request.
func TestMaxInsertSize(t *testing.T) {
	t.Run("test MaxInsertSize", func(t *testing.T) {
		paramtable.Init()

		// Lower the limit to "1" for the duration of this subtest only.
		limitKey := Params.QuotaConfig.MaxInsertSize.Key
		Params.Save(limitKey, "1")
		defer Params.Reset(limitKey)

		req := &msgpb.InsertRequest{
			DbName:         "hooooooo",
			CollectionName: "fooooo",
		}
		task := insertTask{
			ctx:       context.Background(),
			insertMsg: &msgstream.InsertMsg{InsertRequest: req},
		}

		preErr := task.PreExecute(context.Background())
		assert.Error(t, preErr)
		assert.ErrorIs(t, preErr, merr.ErrParameterTooLarge)
	})
}
// TestInsertTask_KeepUserPK_WhenAllowInsertAutoIDTrue checks the case where a
// collection has an AutoID primary key and sets common.AllowInsertAutoIDKey
// to "true". PreExecute must then return the user-supplied primary keys in
// the task result instead of allocated IDs.
func TestInsertTask_KeepUserPK_WhenAllowInsertAutoIDTrue(t *testing.T) {
	paramtable.Init()
	// run auto-id path with field count check; allow user to pass PK
	Params.Save(Params.ProxyCfg.SkipAutoIDCheck.Key, "false")
	defer Params.Reset(Params.ProxyCfg.SkipAutoIDCheck.Key)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	rc := mocks.NewMockRootCoordClient(t)
	rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{
		Status: merr.Status(nil),
		ID:     11198,
		Count:  10,
	}, nil)
	// Check the constructor error first so that a failure stops the test
	// instead of calling Start/Close on a possibly-nil allocator.
	idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	idAllocator.Start()
	defer idAllocator.Close()
	nb := 5
	userIDs := []int64{101, 102, 103, 104, 105}
	collectionName := "TestInsertTask_KeepUserPK"
	schema := &schemapb.CollectionSchema{
		Name:   collectionName,
		AutoID: true,
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
		},
		Properties: []*commonpb.KeyValuePair{
			{Key: common.AllowInsertAutoIDKey, Value: "true"},
		},
	}
	pkFieldData := &schemapb.FieldData{
		FieldName: "id",
		FieldId:   100,
		Type:      schemapb.DataType_Int64,
		Field: &schemapb.FieldData_Scalars{
			Scalars: &schemapb.ScalarField{
				Data: &schemapb.ScalarField_LongData{
					LongData: &schemapb.LongArray{Data: userIDs},
				},
			},
		},
	}
	task := insertTask{
		ctx: context.Background(),
		insertMsg: &BaseInsertTask{
			InsertRequest: &msgpb.InsertRequest{
				CollectionName: collectionName,
				DbName:         "test_db",
				PartitionName:  "_default",
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_Insert,
				},
				Version:    msgpb.InsertDataVersion_ColumnBased,
				FieldsData: []*schemapb.FieldData{pkFieldData},
				NumRows:    uint64(nb),
			},
		},
		idAllocator: idAllocator,
	}
	info := newSchemaInfo(schema)
	cache := NewMockCache(t)
	collectionID := UniqueID(0)
	cache.On("GetCollectionID",
		mock.Anything, // context.Context
		mock.AnythingOfType("string"),
		mock.AnythingOfType("string"),
	).Return(collectionID, nil)
	cache.On("GetCollectionSchema",
		mock.Anything, // context.Context
		mock.AnythingOfType("string"),
		mock.AnythingOfType("string"),
	).Return(info, nil)
	cache.On("GetCollectionInfo",
		mock.Anything,
		mock.Anything,
		mock.Anything,
		mock.Anything,
	).Return(&collectionInfo{schema: info}, nil)
	cache.On("GetDatabaseInfo",
		mock.Anything,
		mock.Anything,
	).Return(&databaseInfo{properties: []*commonpb.KeyValuePair{}}, nil)
	// Install the mock into the package-level cache and restore the previous
	// value afterwards so it does not leak into other tests.
	savedCache := globalMetaCache
	defer func() { globalMetaCache = savedCache }()
	globalMetaCache = cache
	err = task.PreExecute(context.Background())
	assert.NoError(t, err)
	ids := task.result.IDs
	if ids.GetIntId() == nil {
		t.Fatalf("expected int IDs, got nil")
	}
	got := ids.GetIntId().GetData()
	assert.Equal(t, userIDs, got)
}
// TestInsertTask_Function checks that PreExecute succeeds for a collection
// whose "vector" field is the output of a TextEmbedding function over the
// "text" field. The "openai" provider is pointed at a local mock server
// created by embedding.CreateOpenAIEmbeddingServer.
func TestInsertTask_Function(t *testing.T) {
	paramtable.Init()
	// Save the current lookup functions and restore them on exit so the
	// credential/provider overrides (and the soon-closed mock server URL) do
	// not leak into other tests.
	origCredentialFunc := paramtable.Get().CredentialCfg.Credential.GetFunc
	origProvidersFunc := paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc
	defer func() {
		paramtable.Get().CredentialCfg.Credential.GetFunc = origCredentialFunc
		paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = origProvidersFunc
	}()
	paramtable.Get().CredentialCfg.Credential.GetFunc = func() map[string]string {
		return map[string]string{
			"mock.apikey": "mock",
		}
	}
	ts := embedding.CreateOpenAIEmbeddingServer()
	defer ts.Close()
	paramtable.Get().FunctionCfg.TextEmbeddingProviders.GetFunc = func() map[string]string {
		return map[string]string{
			"openai.url": ts.URL,
		}
	}
	data := []*schemapb.FieldData{}
	f := schemapb.FieldData{
		Type:      schemapb.DataType_VarChar,
		FieldId:   101,
		FieldName: "text",
		IsDynamic: false,
		Field: &schemapb.FieldData_Scalars{
			Scalars: &schemapb.ScalarField{
				Data: &schemapb.ScalarField_StringData{
					StringData: &schemapb.StringArray{
						Data: []string{"sentence", "sentence"},
					},
				},
			},
		},
	}
	data = append(data, &f)
	collectionName := "TestInsertTask_function"
	schema := &schemapb.CollectionSchema{
		Name:        collectionName,
		Description: "TestInsertTask_function",
		AutoID:      true,
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
			{
				FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: "max_length", Value: "200"},
				},
			},
			{
				FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: "dim", Value: "4"},
				},
				IsFunctionOutput: true,
			},
		},
		Functions: []*schemapb.FunctionSchema{
			{
				Name:             "test_function",
				Type:             schemapb.FunctionType_TextEmbedding,
				InputFieldIds:    []int64{101},
				InputFieldNames:  []string{"text"},
				OutputFieldIds:   []int64{102},
				OutputFieldNames: []string{"vector"},
				Params: []*commonpb.KeyValuePair{
					{Key: "provider", Value: "openai"},
					{Key: "model_name", Value: "text-embedding-ada-002"},
					{Key: "credential", Value: "mock"},
					{Key: "dim", Value: "4"},
				},
			},
		},
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	rc := mocks.NewMockRootCoordClient(t)
	rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{
		Status: merr.Status(nil),
		ID:     11198,
		Count:  10,
	}, nil)
	// Check the constructor error first so that a failure stops the test
	// instead of calling Start/Close on a possibly-nil allocator.
	idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	idAllocator.Start()
	defer idAllocator.Close()
	task := insertTask{
		ctx: context.Background(),
		insertMsg: &BaseInsertTask{
			InsertRequest: &msgpb.InsertRequest{
				CollectionName: collectionName,
				DbName:         "hooooooo",
				Base: &commonpb.MsgBase{
					MsgType: commonpb.MsgType_Insert,
				},
				Version:    msgpb.InsertDataVersion_ColumnBased,
				FieldsData: data,
				NumRows:    2,
			},
		},
		schema:      schema,
		idAllocator: idAllocator,
	}
	info := newSchemaInfo(schema)
	cache := NewMockCache(t)
	collectionID := UniqueID(0)
	cache.On("GetCollectionID",
		mock.Anything, // context.Context
		mock.AnythingOfType("string"),
		mock.AnythingOfType("string"),
	).Return(collectionID, nil)
	cache.On("GetCollectionSchema",
		mock.Anything, // context.Context
		mock.AnythingOfType("string"),
		mock.AnythingOfType("string"),
	).Return(info, nil)
	cache.On("GetPartitionInfo",
		mock.Anything, // context.Context
		mock.AnythingOfType("string"),
		mock.AnythingOfType("string"),
		mock.AnythingOfType("string"),
	).Return(&partitionInfo{
		name:                "p1",
		partitionID:         10,
		createdTimestamp:    10001,
		createdUtcTimestamp: 10002,
	}, nil)
	cache.On("GetCollectionInfo",
		mock.Anything,
		mock.Anything,
		mock.Anything,
		mock.Anything,
	).Return(&collectionInfo{schema: info}, nil)
	cache.On("GetDatabaseInfo",
		mock.Anything,
		mock.Anything,
	).Return(&databaseInfo{properties: []*commonpb.KeyValuePair{}}, nil)
	// Install the mock into the package-level cache and restore the previous
	// value afterwards so it does not leak into other tests.
	savedCache := globalMetaCache
	defer func() { globalMetaCache = savedCache }()
	globalMetaCache = cache
	err = task.PreExecute(ctx)
	assert.NoError(t, err)
}
// TestInsertTaskForSchemaMismatch verifies that PreExecute fails with
// ErrCollectionSchemaMismatch when the task's schemaTimestamp (99) differs
// from the updateTimestamp (100) of the collection returned by the meta cache.
func TestInsertTaskForSchemaMismatch(t *testing.T) {
	previous := globalMetaCache
	defer func() { globalMetaCache = previous }()

	mc := NewMockCache(t)
	globalMetaCache = mc
	ctx := context.Background()

	t.Run("schema ts mismatch", func(t *testing.T) {
		// The request was prepared against schema ts 99, while the cached
		// collection reports an update at ts 100.
		staleTask := insertTask{
			ctx: context.Background(),
			insertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{
					DbName:         "hooooooo",
					CollectionName: "fooooo",
				},
			},
			schemaTimestamp: 99,
		}

		mc.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil)
		mc.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
			Return(&collectionInfo{updateTimestamp: 100}, nil)
		mc.EXPECT().GetDatabaseInfo(mock.Anything, mock.Anything).Return(&databaseInfo{dbID: 0}, nil)

		gotErr := staleTask.PreExecute(ctx)
		assert.Error(t, gotErr)
		assert.ErrorIs(t, gotErr, merr.ErrCollectionSchemaMismatch)
	})
}
// TestInsertTask_Namespace checks how PreExecute handles the request
// Namespace while CommonCfg.EnableNamespace is on:
//   - On a namespace-enabled collection, a request carrying a namespace
//     succeeds. The namespace field (ID 101) ends up as the first FieldsData
//     column. A request without a namespace fails.
//   - On a collection without the namespace property, a request without a
//     namespace succeeds and one carrying a namespace fails.
func TestInsertTask_Namespace(t *testing.T) {
	paramtable.Init()
	paramtable.Get().CommonCfg.EnableNamespace.SwapTempValue("true")
	defer paramtable.Get().CommonCfg.EnableNamespace.SwapTempValue("false")
	cache := NewMockCache(t)
	// Install the mock into the package-level cache and restore the previous
	// value afterwards so it does not leak into other tests.
	savedCache := globalMetaCache
	defer func() { globalMetaCache = savedCache }()
	globalMetaCache = cache
	cache.On("GetDatabaseInfo",
		mock.Anything,
		mock.Anything,
	).Return(&databaseInfo{properties: []*commonpb.KeyValuePair{}}, nil).Maybe()
	cache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(0, nil).Maybe()
	ctx := context.Background()
	rc := mocks.NewMockRootCoordClient(t)
	rc.EXPECT().AllocID(mock.Anything, mock.Anything).Return(&rootcoordpb.AllocIDResponse{
		Status: merr.Status(nil),
		ID:     11198,
		Count:  10,
	}, nil)
	// Check the constructor error first so that a failure stops the test
	// instead of calling Start/Close on a possibly-nil allocator.
	idAllocator, err := allocator.NewIDAllocator(ctx, rc, 0)
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	idAllocator.Start()
	defer idAllocator.Close()
	schemaWithNamespaceEnabled := &schemapb.CollectionSchema{
		Name: "test",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
			{FieldID: 101, Name: common.NamespaceFieldName, DataType: schemapb.DataType_VarChar, IsPartitionKey: true, TypeParams: []*commonpb.KeyValuePair{
				{Key: common.MaxLengthKey, Value: "100"},
			}},
		},
		Properties: []*commonpb.KeyValuePair{
			{Key: common.NamespaceEnabledKey, Value: "true"},
		},
	}
	schemaWithNamespaceDisabled := &schemapb.CollectionSchema{
		Name: "test",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "id", DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
		},
	}
	t.Run("test insert with namespace enabled", func(t *testing.T) {
		cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Unset()
		cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Unset()
		cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
			schema: newSchemaInfo(schemaWithNamespaceEnabled),
		}, nil).Maybe()
		cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(newSchemaInfo(schemaWithNamespaceEnabled), nil).Maybe()
		namespace := "test"
		it := insertTask{
			ctx: context.Background(),
			insertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{
					CollectionName: "test",
					Namespace:      &namespace,
					NumRows:        100,
					Version:        msgpb.InsertDataVersion_ColumnBased,
				},
			},
			schema:      schemaWithNamespaceEnabled,
			idAllocator: idAllocator,
		}
		err := it.PreExecute(context.Background())
		assert.NoError(t, err)
		assert.Equal(t, int64(101), it.insertMsg.FieldsData[0].FieldId)
		// namespace data is not set
		it = insertTask{
			ctx: context.Background(),
			insertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{
					CollectionName: "test",
					NumRows:        100,
					Version:        msgpb.InsertDataVersion_ColumnBased,
				},
			},
			idAllocator: idAllocator,
		}
		err = it.PreExecute(context.Background())
		assert.Error(t, err)
	})
	t.Run("test insert with namespace disabled", func(t *testing.T) {
		cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Unset()
		cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Unset()
		cache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&collectionInfo{
			schema: newSchemaInfo(schemaWithNamespaceDisabled),
		}, nil).Maybe()
		cache.EXPECT().GetCollectionSchema(mock.Anything, mock.Anything, mock.Anything).Return(newSchemaInfo(schemaWithNamespaceDisabled), nil).Maybe()
		cache.EXPECT().GetPartitionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&partitionInfo{
			name:                "p1",
			partitionID:         10,
			createdTimestamp:    10001,
			createdUtcTimestamp: 10002,
		}, nil).Maybe()
		it := insertTask{
			ctx: context.Background(),
			insertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{
					CollectionName: "test",
					NumRows:        100,
					Version:        msgpb.InsertDataVersion_ColumnBased,
				},
			},
			schema:      schemaWithNamespaceDisabled,
			idAllocator: idAllocator,
		}
		err := it.PreExecute(context.Background())
		assert.NoError(t, err)
		// namespace data is set
		namespace := "test"
		it = insertTask{
			ctx: context.Background(),
			insertMsg: &msgstream.InsertMsg{
				InsertRequest: &msgpb.InsertRequest{
					CollectionName: "test",
					Namespace:      &namespace,
					NumRows:        100,
					Version:        msgpb.InsertDataVersion_ColumnBased,
				},
			},
			idAllocator: idAllocator,
		}
		err = it.PreExecute(context.Background())
		assert.Error(t, err)
	})
}