milvus/internal/proxy/task_test.go
Hao Tan 67c4340565
feat: Geospatial Data Type and GIS Function Support for milvus server (#35990)
issue:https://github.com/milvus-io/milvus/issues/27576

# Main Goals
1. Create and describe collections with geospatial fields, enabling both
client and server to recognize and process geo fields.
2. Insert geospatial data as payload values in the insert binlog, and
print the values for verification.
3. Load segments containing geospatial data into memory.
4. Ensure query outputs can display geospatial data.
5. Support filtering on GIS functions for geospatial columns.

# Solution
1. **Add Type**: Modify the Milvus core by adding a Geospatial type in
both the C++ and Go code layers, defining the Geospatial data structure
and the corresponding interfaces.
2. **Dependency Libraries**: Introduce the necessary geospatial data
processing libraries. In the C++ source code, use the Conan package
manager to include the GDAL library. In the Go source code, add the
go-geom library to the go.mod file.
3. **Protocol Interface**: Revise the Milvus protocol to provide
mechanisms for Geospatial message serialization and deserialization.
4. **Data Pipeline**: Facilitate interaction between the client and
proxy using the WKT format for geospatial data. The proxy will convert
all data into WKB format for downstream processing, providing column
data interfaces, segment encapsulation, segment loading, payload
writing, and cache block management.
5. **Query Operators**: Implement simple display and support for filter
queries. Initially, focus on filtering based on spatial relationships
for a single column of geospatial literal values, providing parsing and
execution for query expressions.
6. **Client Modification**: Enable the client to handle user input for
geospatial data and facilitate end-to-end testing. See the corresponding
changes in pymilvus.

---------

Signed-off-by: tasty-gumi <1021989072@qq.com>
2024-10-31 20:58:20 +08:00

4052 lines
122 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
	"bytes"
	"context"
	"encoding/binary"
	"encoding/json"
	"math/rand"
	"sort"
	"strconv"
	"testing"
	"time"

	"github.com/cockroachdb/errors"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"google.golang.org/grpc"
	"google.golang.org/protobuf/proto"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/allocator"
	"github.com/milvus-io/milvus/internal/mocks"
	"github.com/milvus-io/milvus/internal/proto/indexpb"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/internal/querycoordv2/meta"
	"github.com/milvus-io/milvus/pkg/common"
	"github.com/milvus-io/milvus/pkg/mq/msgstream"
	"github.com/milvus-io/milvus/pkg/util/commonpbutil"
	"github.com/milvus-io/milvus/pkg/util/funcutil"
	"github.com/milvus-io/milvus/pkg/util/merr"
	"github.com/milvus-io/milvus/pkg/util/metric"
	"github.com/milvus-io/milvus/pkg/util/paramtable"
	"github.com/milvus-io/milvus/pkg/util/testutils"
	"github.com/milvus-io/milvus/pkg/util/timerecord"
	"github.com/milvus-io/milvus/pkg/util/typeutil"
	"github.com/milvus-io/milvus/pkg/util/uniquegenerator"
)
// TODO(dragondriver): add more test cases

// Field names shared by the schema helpers and tests in this file.
const (
	testBoolField        = "bool"
	testInt32Field       = "int32"
	testInt64Field       = "int64"
	testFloatField       = "float"
	testDoubleField      = "double"
	testVarCharField     = "varChar"
	testFloatVecField    = "fvec"
	testBinaryVecField   = "bvec"
	testFloat16VecField  = "f16vec"
	testBFloat16VecField = "bf16vec"
	testGeometryField    = "geometry"
)

// Size limits: maximum test string length, vector dimension, and the VarChar
// max_length type param applied by constructCollectionSchemaByDataType.
const (
	maxTestStringLen     = 100
	testVecDim           = 128
	testMaxVarCharLength = 100
)
// genCollectionSchema returns the default all-type test schema named
// collectionName. It is built from the package-level test field-name
// constants and testVecDim via constructCollectionSchemaWithAllType.
func genCollectionSchema(collectionName string) *schemapb.CollectionSchema {
	schema := constructCollectionSchemaWithAllType(
		testBoolField, testInt32Field, testInt64Field, testFloatField, testDoubleField,
		testFloatVecField, testBinaryVecField, testFloat16VecField, testBFloat16VecField,
		testGeometryField,
		testVecDim,
		collectionName,
	)
	return schema
}
// constructCollectionSchema returns a two-field schema named collectionName:
//   - field 100: an Int64 primary key named int64Field with AutoID enabled;
//   - field 101: a FloatVector named floatVecField whose dim type param is dim.
//
// The collection-level AutoID flag is left false.
func constructCollectionSchema(
	int64Field, floatVecField string,
	dim int,
	collectionName string,
) *schemapb.CollectionSchema {
	dimParams := []*commonpb.KeyValuePair{
		{Key: common.DimKey, Value: strconv.Itoa(dim)},
	}
	fields := []*schemapb.FieldSchema{
		{
			FieldID:      100,
			Name:         int64Field,
			IsPrimaryKey: true,
			DataType:     schemapb.DataType_Int64,
			AutoID:       true,
		},
		{
			FieldID:    101,
			Name:       floatVecField,
			DataType:   schemapb.DataType_FloatVector,
			TypeParams: dimParams,
		},
	}
	return &schemapb.CollectionSchema{
		Name:   collectionName,
		Fields: fields,
	}
}
// constructCollectionSchemaEnableDynamicSchema returns a schema named
// collectionName with EnableDynamicField turned on and two fields, both with
// FieldID 0:
//   - an Int64 primary key named int64Field with AutoID enabled;
//   - a FloatVector named floatVecField whose dim type param is dim.
func constructCollectionSchemaEnableDynamicSchema(
	int64Field, floatVecField string,
	dim int,
	collectionName string,
) *schemapb.CollectionSchema {
	primary := &schemapb.FieldSchema{
		Name:         int64Field,
		IsPrimaryKey: true,
		DataType:     schemapb.DataType_Int64,
		AutoID:       true,
	}
	vector := &schemapb.FieldSchema{
		Name:     floatVecField,
		DataType: schemapb.DataType_FloatVector,
		TypeParams: []*commonpb.KeyValuePair{
			{Key: common.DimKey, Value: strconv.Itoa(dim)},
		},
	}

	schema := &schemapb.CollectionSchema{
		Name:               collectionName,
		EnableDynamicField: true,
	}
	schema.Fields = append(schema.Fields, primary, vector)
	return schema
}
// ConstructCollectionSchemaWithPartitionKey builds the same schema as
// constructCollectionSchemaByDataType. It then flags every field whose name
// equals partitionKeyFieldName as the partition key. If no field has that
// name, the schema is returned without a partition key.
func ConstructCollectionSchemaWithPartitionKey(collectionName string, fieldName2DataType map[string]schemapb.DataType, primaryFieldName string, partitionKeyFieldName string, autoID bool) *schemapb.CollectionSchema {
	schema := constructCollectionSchemaByDataType(collectionName, fieldName2DataType, primaryFieldName, autoID)
	for i := range schema.Fields {
		if schema.Fields[i].Name != partitionKeyFieldName {
			continue
		}
		schema.Fields[i].IsPartitionKey = true
	}
	return schema
}
// constructCollectionSchemaByDataType builds a schema named collectionName
// with one field per entry in fieldName2DataType.
//
// Field IDs are assigned sequentially from 100 in ascending field-name order.
// Go randomizes map iteration order, so the names are sorted first. This makes
// both the field IDs and the order of schema.Fields reproducible between runs,
// which matters because callers index into schema.Fields.
//
// Type params are set as follows:
//   - vector fields get a dim type param of testVecDim;
//   - VarChar fields get a max_length type param of testMaxVarCharLength.
//
// The field named primaryFieldName, if present, is marked as the primary key,
// with AutoID set to autoID.
func constructCollectionSchemaByDataType(collectionName string, fieldName2DataType map[string]schemapb.DataType, primaryFieldName string, autoID bool) *schemapb.CollectionSchema {
	fieldNames := make([]string, 0, len(fieldName2DataType))
	for fieldName := range fieldName2DataType {
		fieldNames = append(fieldNames, fieldName)
	}
	sort.Strings(fieldNames)

	fieldsSchema := make([]*schemapb.FieldSchema, 0, len(fieldNames))
	idx := int64(100)
	for _, fieldName := range fieldNames {
		dataType := fieldName2DataType[fieldName]
		fieldSchema := &schemapb.FieldSchema{
			FieldID:  idx,
			Name:     fieldName,
			DataType: dataType,
		}
		idx++
		if typeutil.IsVectorType(dataType) {
			fieldSchema.TypeParams = []*commonpb.KeyValuePair{
				{
					Key:   common.DimKey,
					Value: strconv.Itoa(testVecDim),
				},
			}
		}
		if dataType == schemapb.DataType_VarChar {
			fieldSchema.TypeParams = []*commonpb.KeyValuePair{
				{
					Key:   common.MaxLengthKey,
					Value: strconv.Itoa(testMaxVarCharLength),
				},
			}
		}
		if fieldName == primaryFieldName {
			fieldSchema.IsPrimaryKey = true
			fieldSchema.AutoID = autoID
		}
		fieldsSchema = append(fieldsSchema, fieldSchema)
	}
	return &schemapb.CollectionSchema{
		Name:   collectionName,
		Fields: fieldsSchema,
	}
}
// constructCollectionSchemaWithAllType builds a schema named collectionName
// that covers the scalar, vector and geometry types used by the proxy tests.
// Every field keeps FieldID 0 and AutoID false. The Int64 field named
// int64Field is the only primary key, and each vector field carries a dim type
// param of dim.
//
// Fields appear in this order:
//  1. bool, int32, int64, float, double;
//  2. the float vector;
//  3. only when enableMultipleVectorFields is set: the binary, float16 and
//     bfloat16 vectors;
//  4. the geometry field.
func constructCollectionSchemaWithAllType(
	boolField, int32Field, int64Field, floatField, doubleField string,
	floatVecField, binaryVecField, float16VecField, bfloat16VecField string,
	geometryField string,
	dim int,
	collectionName string,
) *schemapb.CollectionSchema {
	// newField builds a field with no type or index params.
	newField := func(name string, dataType schemapb.DataType, isPrimaryKey bool) *schemapb.FieldSchema {
		return &schemapb.FieldSchema{
			Name:         name,
			IsPrimaryKey: isPrimaryKey,
			DataType:     dataType,
		}
	}
	// newVectorField builds a non-primary field carrying the dim type param.
	newVectorField := func(name string, dataType schemapb.DataType) *schemapb.FieldSchema {
		field := newField(name, dataType, false)
		field.TypeParams = []*commonpb.KeyValuePair{
			{Key: common.DimKey, Value: strconv.Itoa(dim)},
		}
		return field
	}

	fields := []*schemapb.FieldSchema{
		newField(boolField, schemapb.DataType_Bool, false),
		newField(int32Field, schemapb.DataType_Int32, false),
		newField(int64Field, schemapb.DataType_Int64, true),
		newField(floatField, schemapb.DataType_Float, false),
		newField(doubleField, schemapb.DataType_Double, false),
		newVectorField(floatVecField, schemapb.DataType_FloatVector),
	}
	if enableMultipleVectorFields {
		fields = append(fields,
			newVectorField(binaryVecField, schemapb.DataType_BinaryVector),
			newVectorField(float16VecField, schemapb.DataType_Float16Vector),
			newVectorField(bfloat16VecField, schemapb.DataType_BFloat16Vector),
		)
	}
	fields = append(fields, newField(geometryField, schemapb.DataType_Geometry, false))

	return &schemapb.CollectionSchema{
		Name:   collectionName,
		Fields: fields,
	}
}
// constructPlaceholderGroup returns a placeholder group with one FloatVector
// placeholder tagged "$0". The placeholder holds nq random vectors of dim
// float32 components each, serialized with common.Endian at 4 bytes per
// component. It panics if encoding fails.
func constructPlaceholderGroup(
	nq, dim int,
) *commonpb.PlaceholderGroup {
	vectors := make([][]byte, nq)
	for row := range vectors {
		components := make([]float32, dim)
		for col := range components {
			components[col] = rand.Float32()
		}
		// A []float32 is encoded element by element in order, producing the
		// same bytes as writing each component individually.
		var encoded bytes.Buffer
		encoded.Grow(dim * 4)
		if err := binary.Write(&encoded, common.Endian, components); err != nil {
			panic(err)
		}
		vectors[row] = encoded.Bytes()
	}

	placeholder := &commonpb.PlaceholderValue{
		Tag:    "$0",
		Type:   commonpb.PlaceholderType_FloatVector,
		Values: vectors,
	}
	return &commonpb.PlaceholderGroup{
		Placeholders: []*commonpb.PlaceholderValue{placeholder},
	}
}
// constructSearchRequest builds a SearchRequest against floatVecField with the
// following contents:
//   - expr as a BoolExprV1 DSL;
//   - nq random query vectors of dimension dim as the placeholder group;
//   - L2 metric, the given topk and round-decimal settings, and
//     {"nprobe": nprobe} as the JSON-encoded search params.
//
// It panics if marshaling either the params or the placeholders fails.
func constructSearchRequest(
	dbName, collectionName string,
	expr string,
	floatVecField string,
	nq, dim, nprobe, topk, roundDecimal int,
) *milvuspb.SearchRequest {
	searchParamsJSON, err := json.Marshal(map[string]string{"nprobe": strconv.Itoa(nprobe)})
	if err != nil {
		panic(err)
	}
	placeholderBytes, err := proto.Marshal(constructPlaceholderGroup(nq, dim))
	if err != nil {
		panic(err)
	}

	kv := func(key, value string) *commonpb.KeyValuePair {
		return &commonpb.KeyValuePair{Key: key, Value: value}
	}
	return &milvuspb.SearchRequest{
		DbName:           dbName,
		CollectionName:   collectionName,
		Dsl:              expr,
		PlaceholderGroup: placeholderBytes,
		DslType:          commonpb.DslType_BoolExprV1,
		SearchParams: []*commonpb.KeyValuePair{
			kv(common.MetricTypeKey, metric.L2),
			kv(SearchParamsKey, string(searchParamsJSON)),
			kv(AnnsFieldKey, floatVecField),
			kv(TopKKey, strconv.Itoa(topk)),
			kv(RoundDecimalKey, strconv.Itoa(roundDecimal)),
		},
	}
}
func TestTranslateOutputFields(t *testing.T) {
const (
idFieldName = "id"
tsFieldName = "timestamp"
floatVectorFieldName = "float_vector"
binaryVectorFieldName = "binary_vector"
float16VectorFieldName = "float16_vector"
bfloat16VectorFieldName = "bfloat16_vector"
)
var outputFields []string
var userOutputFields []string
var userDynamicFields []string
var err error
collSchema := &schemapb.CollectionSchema{
Name: "TestTranslateOutputFields",
Description: "TestTranslateOutputFields",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{Name: idFieldName, FieldID: 0, DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{Name: tsFieldName, FieldID: 1, DataType: schemapb.DataType_Int64},
{Name: floatVectorFieldName, FieldID: 100, DataType: schemapb.DataType_FloatVector},
{Name: binaryVectorFieldName, FieldID: 101, DataType: schemapb.DataType_BinaryVector},
{Name: float16VectorFieldName, FieldID: 102, DataType: schemapb.DataType_Float16Vector},
{Name: bfloat16VectorFieldName, FieldID: 103, DataType: schemapb.DataType_BFloat16Vector},
},
}
schema := newSchemaInfo(collSchema)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{}, outputFields)
assert.ElementsMatch(t, []string{}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{idFieldName}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{idFieldName, tsFieldName}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{idFieldName, tsFieldName, floatVectorFieldName}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{"*"}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{" * "}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{"*", tsFieldName}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{"*", floatVectorFieldName}, schema, false)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)
//=========================================================================
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{}, schema, true)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{idFieldName}, schema, true)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{idFieldName, tsFieldName}, schema, true)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{idFieldName, tsFieldName, floatVectorFieldName}, schema, true)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{"*"}, schema, true)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{"*", tsFieldName}, schema, true)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{"*", floatVectorFieldName}, schema, true)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, outputFields)
assert.ElementsMatch(t, []string{idFieldName, tsFieldName, floatVectorFieldName, binaryVectorFieldName, float16VectorFieldName, bfloat16VectorFieldName}, userOutputFields)
assert.ElementsMatch(t, []string{}, userDynamicFields)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{"A"}, schema, true)
assert.Error(t, err)
t.Run("enable dynamic schema", func(t *testing.T) {
collSchema := &schemapb.CollectionSchema{
Name: "TestTranslateOutputFields",
Description: "TestTranslateOutputFields",
AutoID: false,
EnableDynamicField: true,
Fields: []*schemapb.FieldSchema{
{Name: idFieldName, FieldID: 1, DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{Name: tsFieldName, FieldID: 2, DataType: schemapb.DataType_Int64},
{Name: floatVectorFieldName, FieldID: 100, DataType: schemapb.DataType_FloatVector},
{Name: binaryVectorFieldName, FieldID: 101, DataType: schemapb.DataType_BinaryVector},
{Name: common.MetaFieldName, FieldID: 102, DataType: schemapb.DataType_JSON, IsDynamic: true},
},
}
schema := newSchemaInfo(collSchema)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{"A", idFieldName}, schema, true)
assert.Equal(t, nil, err)
assert.ElementsMatch(t, []string{common.MetaFieldName, idFieldName}, outputFields)
assert.ElementsMatch(t, []string{"A", idFieldName}, userOutputFields)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{idFieldName, floatVectorFieldName, "$meta[\"A\"]"}, schema, true)
assert.Error(t, err)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{idFieldName, floatVectorFieldName, "$meta[]"}, schema, true)
assert.Error(t, err)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{idFieldName, floatVectorFieldName, "$meta[\"\"]"}, schema, true)
assert.Error(t, err)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{idFieldName, floatVectorFieldName, "$meta["}, schema, true)
assert.Error(t, err)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{idFieldName, floatVectorFieldName, "[]"}, schema, true)
assert.Error(t, err)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{idFieldName, floatVectorFieldName, "A > 1"}, schema, true)
assert.Error(t, err)
outputFields, userOutputFields, userDynamicFields, err = translateOutputFields([]string{idFieldName, floatVectorFieldName, ""}, schema, true)
assert.Error(t, err)
})
}
// TestCreateCollectionTask covers createCollectionTask against a RootCoord
// mock:
//   - the task accessors (type, trace ctx, ID, name, timestamps);
//   - a full PreExecute/Execute/PostExecute run;
//   - the schema validation failures reported by PreExecute;
//   - rejection of a user-specified dynamic field.
func TestCreateCollectionTask(t *testing.T) {
	rc := NewRootCoordMock()
	defer rc.Close()
	ctx := context.Background()
	shardsNum := common.DefaultShardsNum
	prefix := "TestCreateCollectionTask"
	dbName := ""
	collectionName := prefix + funcutil.GenRandomStr()
	int64Field := "int64"
	floatVecField := "fvec"
	varCharField := "varChar"
	fieldName2Type := make(map[string]schemapb.DataType)
	fieldName2Type[int64Field] = schemapb.DataType_Int64
	fieldName2Type[varCharField] = schemapb.DataType_VarChar
	fieldName2Type[floatVecField] = schemapb.DataType_FloatVector
	schema := constructCollectionSchemaByDataType(collectionName, fieldName2Type, int64Field, false)
	marshaledSchema, err := proto.Marshal(schema)
	assert.NoError(t, err)
	task := &createCollectionTask{
		Condition: NewTaskCondition(ctx),
		CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
			Base:           nil,
			DbName:         dbName,
			CollectionName: collectionName,
			Schema:         marshaledSchema,
			ShardsNum:      shardsNum,
		},
		ctx:       ctx,
		rootCoord: rc,
		result:    nil,
		schema:    nil,
	}
	t.Run("on enqueue", func(t *testing.T) {
		err := task.OnEnqueue()
		assert.NoError(t, err)
		assert.Equal(t, commonpb.MsgType_CreateCollection, task.Type())
	})
	t.Run("ctx", func(t *testing.T) {
		traceCtx := task.TraceCtx()
		assert.NotNil(t, traceCtx)
	})
	t.Run("id", func(t *testing.T) {
		id := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
		task.SetID(id)
		assert.Equal(t, id, task.ID())
	})
	t.Run("name", func(t *testing.T) {
		assert.Equal(t, CreateCollectionTaskName, task.Name())
	})
	t.Run("ts", func(t *testing.T) {
		ts := Timestamp(time.Now().UnixNano())
		task.SetTs(ts)
		assert.Equal(t, ts, task.BeginTs())
		assert.Equal(t, ts, task.EndTs())
	})
	t.Run("process task", func(t *testing.T) {
		var err error
		err = task.PreExecute(ctx)
		assert.NoError(t, err)
		err = task.Execute(ctx)
		assert.NoError(t, err)
		assert.Equal(t, commonpb.ErrorCode_Success, task.result.ErrorCode)
		// Executing the same create request again is expected to succeed
		// against this RootCoord mock.
		err = task.Execute(ctx)
		assert.NoError(t, err)
		assert.Equal(t, commonpb.ErrorCode_Success, task.result.ErrorCode)
		err = task.PostExecute(ctx)
		assert.NoError(t, err)
	})
	t.Run("PreExecute", func(t *testing.T) {
		var err error
		err = task.PreExecute(ctx)
		assert.NoError(t, err)

		// The schema has no partition key, so it must be rejected while
		// MustUsePartitionKey is enabled.
		Params.Save(Params.ProxyCfg.MustUsePartitionKey.Key, "true")
		err = task.PreExecute(ctx)
		assert.Error(t, err)
		assert.ErrorIs(t, err, merr.ErrParameterInvalid)
		Params.Reset(Params.ProxyCfg.MustUsePartitionKey.Key)

		// undecodable schema bytes
		task.Schema = []byte{0x1, 0x2, 0x3, 0x4}
		err = task.PreExecute(ctx)
		assert.Error(t, err)
		task.Schema = marshaledSchema

		// too many shards
		task.ShardsNum = Params.ProxyCfg.MaxShardNum.GetAsInt32() + 1
		err = task.PreExecute(ctx)
		assert.Error(t, err)
		task.ShardsNum = shardsNum

		reqBackup := proto.Clone(task.CreateCollectionRequest).(*milvuspb.CreateCollectionRequest)
		schemaBackup := proto.Clone(schema).(*schemapb.CollectionSchema)

		// too many fields
		schemaWithTooManyFields := &schemapb.CollectionSchema{
			Name:        collectionName,
			Description: "",
			AutoID:      false,
			Fields:      make([]*schemapb.FieldSchema, Params.ProxyCfg.MaxFieldNum.GetAsInt32()+1),
		}
		marshaledSchemaWithTooManyFields, err := proto.Marshal(schemaWithTooManyFields)
		assert.NoError(t, err)
		task.CreateCollectionRequest.Schema = marshaledSchemaWithTooManyFields
		err = task.PreExecute(ctx)
		assert.Error(t, err)

		// Too many vector fields: the cloned schema already has one float
		// vector field, so adding MaxVectorFieldNum more exceeds the limit.
		// No existing field is duplicated here; a duplicated name would make
		// PreExecute fail for a different reason and mask this check.
		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
		for i := 0; i < Params.ProxyCfg.MaxVectorFieldNum.GetAsInt(); i++ {
			schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
				FieldID:      101,
				Name:         floatVecField + "_" + strconv.Itoa(i),
				IsPrimaryKey: false,
				Description:  "",
				DataType:     schemapb.DataType_FloatVector,
				TypeParams: []*commonpb.KeyValuePair{
					{
						Key:   common.DimKey,
						Value: strconv.Itoa(testVecDim),
					},
				},
				IndexParams: nil,
				AutoID:      false,
			})
		}
		tooManyVectorFieldsSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		task.CreateCollectionRequest.Schema = tooManyVectorFieldsSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)

		// without vector field
		schema = &schemapb.CollectionSchema{
			Name:        collectionName,
			Description: "",
			AutoID:      false,
			Fields: []*schemapb.FieldSchema{
				{
					Name:         "id",
					DataType:     schemapb.DataType_Int64,
					IsPrimaryKey: true,
				},
			},
		}
		noVectorSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		task.CreateCollectionRequest.Schema = noVectorSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)

		task.CreateCollectionRequest = reqBackup

		// validateCollectionName
		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
		schema.Name = " " // empty
		emptyNameSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		task.CreateCollectionRequest.Schema = emptyNameSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)

		schema.Name = prefix
		for i := 0; i < Params.ProxyCfg.MaxNameLength.GetAsInt(); i++ {
			schema.Name += strconv.Itoa(i % 10)
		}
		tooLongNameSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		task.CreateCollectionRequest.Schema = tooLongNameSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)

		schema.Name = "$" // invalid first char
		invalidFirstCharSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		task.CreateCollectionRequest.Schema = invalidFirstCharSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)

		// validateDuplicatedFieldName
		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
		schema.Fields = append(schema.Fields, schema.Fields[0])
		duplicatedFieldsSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		task.CreateCollectionRequest.Schema = duplicatedFieldsSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)

		// validatePrimaryKey
		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
		for idx := range schema.Fields {
			schema.Fields[idx].IsPrimaryKey = false
		}
		noPrimaryFieldsSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		task.CreateCollectionRequest.Schema = noPrimaryFieldsSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)

		// validateFieldName
		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
		for idx := range schema.Fields {
			schema.Fields[idx].Name = "$"
		}
		invalidFieldNameSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		task.CreateCollectionRequest.Schema = invalidFieldNameSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)

		// validateMaxLengthPerRow: VarChar without max_length
		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
		for idx := range schema.Fields {
			if schema.Fields[idx].DataType == schemapb.DataType_VarChar {
				schema.Fields[idx].TypeParams = nil
			}
		}
		noTypeParamsSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		task.CreateCollectionRequest.Schema = noTypeParamsSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)

		// ValidateVectorField: vector field without dim
		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
		for idx := range schema.Fields {
			if schema.Fields[idx].DataType == schemapb.DataType_FloatVector ||
				schema.Fields[idx].DataType == schemapb.DataType_BinaryVector {
				schema.Fields[idx].TypeParams = nil
			}
		}
		noDimSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		task.CreateCollectionRequest.Schema = noDimSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)

		// Default load list: marking every field (vector field included) as
		// skip-load is expected to be rejected.
		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
		for _, field := range schema.Fields {
			field.TypeParams = append(field.TypeParams, &commonpb.KeyValuePair{
				Key:   common.FieldSkipLoadKey,
				Value: "true",
			})
		}
		skipLoadSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		task.CreateCollectionRequest.Schema = skipLoadSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)

		// ValidateVectorField: non-integer dim
		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
		for idx := range schema.Fields {
			if schema.Fields[idx].DataType == schemapb.DataType_FloatVector ||
				schema.Fields[idx].DataType == schemapb.DataType_BinaryVector {
				schema.Fields[idx].TypeParams = []*commonpb.KeyValuePair{
					{
						Key:   common.DimKey,
						Value: "not int",
					},
				}
			}
		}
		dimNotIntSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		task.CreateCollectionRequest.Schema = dimNotIntSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)

		// ValidateVectorField: dim above MaxDimension
		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
		for idx := range schema.Fields {
			if schema.Fields[idx].DataType == schemapb.DataType_FloatVector ||
				schema.Fields[idx].DataType == schemapb.DataType_BinaryVector {
				schema.Fields[idx].TypeParams = []*commonpb.KeyValuePair{
					{
						Key:   common.DimKey,
						Value: strconv.Itoa(Params.ProxyCfg.MaxDimension.GetAsInt() + 1),
					},
				}
			}
		}
		tooLargeDimSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		task.CreateCollectionRequest.Schema = tooLargeDimSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)

		// NOTE(review): which field Fields[1] is depends on the field order
		// produced by constructCollectionSchemaByDataType; confirm this case
		// really targets a vector field rather than the primary key.
		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
		schema.Fields[1].DataType = schemapb.DataType_BinaryVector
		schema.Fields[1].TypeParams = []*commonpb.KeyValuePair{
			{
				Key:   common.DimKey,
				Value: strconv.Itoa(Params.ProxyCfg.MaxDimension.GetAsInt() + 1),
			},
		}
		binaryTooLargeDimSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		task.CreateCollectionRequest.Schema = binaryTooLargeDimSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)

		// A second float vector field is accepted only when multiple vector
		// fields are enabled.
		schema = proto.Clone(schemaBackup).(*schemapb.CollectionSchema)
		schema.Fields = append(schema.Fields, &schemapb.FieldSchema{
			FieldID:      0,
			Name:         "second_vector",
			IsPrimaryKey: false,
			Description:  "",
			DataType:     schemapb.DataType_FloatVector,
			TypeParams: []*commonpb.KeyValuePair{
				{
					Key:   common.DimKey,
					Value: strconv.Itoa(128),
				},
			},
			IndexParams: nil,
			AutoID:      false,
		})
		twoVecFieldsSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		task.CreateCollectionRequest.Schema = twoVecFieldsSchema
		err = task.PreExecute(ctx)
		if enableMultipleVectorFields {
			assert.NoError(t, err)
		} else {
			assert.Error(t, err)
		}
	})
	t.Run("specify dynamic field", func(t *testing.T) {
		// A user-declared field flagged IsDynamic is expected to be rejected
		// by PreExecute.
		dynamicField := &schemapb.FieldSchema{
			Name:      "json",
			IsDynamic: true,
		}
		schema2 := &schemapb.CollectionSchema{
			Name:   collectionName,
			Fields: append(schema.Fields, dynamicField),
		}
		marshaledSchema, err := proto.Marshal(schema2)
		assert.NoError(t, err)
		task2 := &createCollectionTask{
			Condition: NewTaskCondition(ctx),
			CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
				Base:           nil,
				DbName:         dbName,
				CollectionName: collectionName,
				Schema:         marshaledSchema,
				ShardsNum:      shardsNum,
			},
			ctx:       ctx,
			rootCoord: rc,
			result:    nil,
			schema:    nil,
		}
		err = task2.OnEnqueue()
		assert.NoError(t, err)
		err = task2.PreExecute(ctx)
		assert.Error(t, err)
	})
}
// TestHasCollectionTask exercises hasCollectionTask against a RootCoordMock:
// the task metadata accessors, HasCollection before and after the collection
// is created, name validation in PreExecute, a rootcoord that answers
// DescribeCollection with collection-not-found, and a rootcoord in an abnormal
// state.
func TestHasCollectionTask(t *testing.T) {
	rc := NewRootCoordMock()
	defer rc.Close()
	qc := getQueryCoordClient()
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
	ctx := context.Background()
	mgr := newShardClientMgr()
	InitMetaCache(ctx, rc, qc, mgr)
	prefix := "TestHasCollectionTask"
	dbName := ""
	collectionName := prefix + funcutil.GenRandomStr()
	shardsNum := int32(2)
	int64Field := "int64"
	floatVecField := "fvec"
	dim := 128
	schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
	marshaledSchema, err := proto.Marshal(schema)
	assert.NoError(t, err)
	// Request used later to create the collection directly in the mock
	// rootcoord. It is a CreateCollection request, so it carries the matching
	// MsgType.
	createColReq := &milvuspb.CreateCollectionRequest{
		Base: &commonpb.MsgBase{
			MsgType:   commonpb.MsgType_CreateCollection,
			MsgID:     100,
			Timestamp: 100,
		},
		DbName:         dbName,
		CollectionName: collectionName,
		Schema:         marshaledSchema,
		ShardsNum:      shardsNum,
	}
	// HasCollection task under test.
	task := &hasCollectionTask{
		Condition: NewTaskCondition(ctx),
		HasCollectionRequest: &milvuspb.HasCollectionRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_HasCollection,
				MsgID:     100,
				Timestamp: 100,
			},
			DbName:         dbName,
			CollectionName: collectionName,
		},
		ctx:       ctx,
		rootCoord: rc,
		result:    nil,
	}
	// Results are intentionally ignored here; only the accessors below are checked.
	task.OnEnqueue()
	task.PreExecute(ctx)
	assert.Equal(t, commonpb.MsgType_HasCollection, task.Type())
	assert.Equal(t, UniqueID(100), task.ID())
	assert.Equal(t, Timestamp(100), task.BeginTs())
	assert.Equal(t, Timestamp(100), task.EndTs())
	assert.Equal(t, paramtable.GetNodeID(), task.GetBase().GetSourceID())
	// The collection has not been created yet, so it is reported as absent.
	err = task.Execute(ctx)
	assert.NoError(t, err)
	assert.Equal(t, false, task.result.Value)
	// Create the collection in the mock rootcoord and warm the global meta cache.
	createStatus, err := rc.CreateCollection(ctx, createColReq)
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, createStatus.GetErrorCode())
	globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
	// The collection now exists.
	err = task.Execute(ctx)
	assert.NoError(t, err)
	assert.Equal(t, true, task.result.Value)
	// An illegal collection name is rejected by PreExecute.
	task.CollectionName = "#0xc0de"
	err = task.PreExecute(ctx)
	assert.Error(t, err)
	task.CollectionName = collectionName
	// Invalidate the collection cache so the next lookup goes through rootcoord.
	globalMetaCache.RemoveCollection(ctx, dbName, collectionName)
	// rootcoord now answers DescribeCollection with collection-not-found.
	rc.describeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) {
		return nil, merr.WrapErrCollectionNotFoundWithDB(dbName, collectionName)
	}
	err = task.PreExecute(ctx)
	assert.NoError(t, err)
	err = task.Execute(ctx)
	assert.NoError(t, err)
	assert.False(t, task.result.GetValue())
	// With rootcoord in an abnormal state, Execute must fail.
	rc.updateState(commonpb.StateCode_Abnormal)
	err = task.PreExecute(ctx)
	assert.NoError(t, err)
	err = task.Execute(ctx)
	assert.Error(t, err)
}
// TestDescribeCollectionTask exercises describeCollectionTask against a
// RootCoordMock: the task metadata accessors, Execute for a collection that is
// never created, PreExecute name validation, describing by collection ID, and
// Execute after the mock rootcoord has been closed.
func TestDescribeCollectionTask(t *testing.T) {
	rc := NewRootCoordMock()
	defer rc.Close()
	qc := getQueryCoordClient()
	ctx := context.Background()
	mgr := newShardClientMgr()
	InitMetaCache(ctx, rc, qc, mgr)
	prefix := "TestDescribeCollectionTask"
	dbName := ""
	collectionName := prefix + funcutil.GenRandomStr()
	// DescribeCollection task under test; this test never creates the collection.
	task := &describeCollectionTask{
		Condition: NewTaskCondition(ctx),
		DescribeCollectionRequest: &milvuspb.DescribeCollectionRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_DescribeCollection,
				MsgID:     100,
				Timestamp: 100,
			},
			DbName:         dbName,
			CollectionName: collectionName,
		},
		ctx:       ctx,
		rootCoord: rc,
		result:    nil,
	}
	// Results are intentionally ignored; only the accessors below are checked.
	task.OnEnqueue()
	task.PreExecute(ctx)
	assert.Equal(t, commonpb.MsgType_DescribeCollection, task.Type())
	assert.Equal(t, UniqueID(100), task.ID())
	assert.Equal(t, Timestamp(100), task.BeginTs())
	assert.Equal(t, Timestamp(100), task.EndTs())
	assert.Equal(t, paramtable.GetNodeID(), task.GetBase().GetSourceID())
	// The collection is not in globalMetaCache; Execute still returns no error.
	err := task.Execute(ctx)
	assert.NoError(t, err)
	// An illegal collection name is rejected by PreExecute.
	task.CollectionName = "#0xc0de"
	err = task.PreExecute(ctx)
	assert.Error(t, err)
	// Describe by collection ID: an empty name with a non-zero ID passes PreExecute.
	task.CollectionID = 1
	task.CollectionName = ""
	err = task.PreExecute(ctx)
	assert.NoError(t, err)
	// Close the mock rootcoord early so the next Execute gets a non-success
	// status (the deferred rc.Close() above runs again at the end of the test).
	rc.Close()
	task.CollectionID = 0
	task.CollectionName = collectionName
	err = task.PreExecute(ctx)
	assert.NoError(t, err)
	err = task.Execute(ctx)
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_UnexpectedError, task.result.GetStatus().GetErrorCode())
}
// TestDescribeCollectionTask_ShardsNum1 creates a collection with
// common.DefaultShardsNum shards in the mock rootcoord and checks that
// describeCollectionTask reports that shard count and the collection name.
func TestDescribeCollectionTask_ShardsNum1(t *testing.T) {
	rc := NewRootCoordMock()
	defer rc.Close()
	qc := getQueryCoordClient()
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
	ctx := context.Background()
	mgr := newShardClientMgr()
	InitMetaCache(ctx, rc, qc, mgr)
	prefix := "TestDescribeCollectionTask"
	dbName := ""
	collectionName := prefix + funcutil.GenRandomStr()
	shardsNum := common.DefaultShardsNum
	int64Field := "int64"
	floatVecField := "fvec"
	dim := 128
	schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
	marshaledSchema, err := proto.Marshal(schema)
	assert.NoError(t, err)
	// This is a CreateCollection request, so it carries the matching MsgType.
	createColReq := &milvuspb.CreateCollectionRequest{
		Base: &commonpb.MsgBase{
			MsgType:   commonpb.MsgType_CreateCollection,
			MsgID:     100,
			Timestamp: 100,
		},
		DbName:         dbName,
		CollectionName: collectionName,
		Schema:         marshaledSchema,
		ShardsNum:      shardsNum,
	}
	// Create the collection directly in the mock rootcoord and warm the meta cache.
	createStatus, err := rc.CreateCollection(ctx, createColReq)
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, createStatus.GetErrorCode())
	globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
	// DescribeCollection task under test.
	task := &describeCollectionTask{
		Condition: NewTaskCondition(ctx),
		DescribeCollectionRequest: &milvuspb.DescribeCollectionRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_DescribeCollection,
				MsgID:     100,
				Timestamp: 100,
			},
			DbName:         dbName,
			CollectionName: collectionName,
		},
		ctx:       ctx,
		rootCoord: rc,
		result:    nil,
	}
	err = task.PreExecute(ctx)
	assert.NoError(t, err)
	err = task.Execute(ctx)
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, task.result.GetStatus().GetErrorCode())
	assert.Equal(t, shardsNum, task.result.ShardsNum)
	assert.Equal(t, collectionName, task.result.GetCollectionName())
}
// TestDescribeCollectionTask_EnableDynamicSchema creates a collection from
// constructCollectionSchemaEnableDynamicSchema in the mock rootcoord and checks
// the described status, shard count, name, and the number of schema fields
// returned by describeCollectionTask.
func TestDescribeCollectionTask_EnableDynamicSchema(t *testing.T) {
	rc := NewRootCoordMock()
	defer rc.Close()
	qc := getQueryCoordClient()
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
	ctx := context.Background()
	mgr := newShardClientMgr()
	InitMetaCache(ctx, rc, qc, mgr)
	prefix := "TestDescribeCollectionTask"
	dbName := ""
	collectionName := prefix + funcutil.GenRandomStr()
	shardsNum := common.DefaultShardsNum
	int64Field := "int64"
	floatVecField := "fvec"
	dim := 128
	schema := constructCollectionSchemaEnableDynamicSchema(int64Field, floatVecField, dim, collectionName)
	marshaledSchema, err := proto.Marshal(schema)
	assert.NoError(t, err)
	// This is a CreateCollection request, so it carries the matching MsgType.
	createColReq := &milvuspb.CreateCollectionRequest{
		Base: &commonpb.MsgBase{
			MsgType:   commonpb.MsgType_CreateCollection,
			MsgID:     100,
			Timestamp: 100,
		},
		DbName:         dbName,
		CollectionName: collectionName,
		Schema:         marshaledSchema,
		ShardsNum:      shardsNum,
	}
	// Create the collection directly in the mock rootcoord and warm the meta
	// cache under the same database name the sibling tests use.
	createStatus, err := rc.CreateCollection(ctx, createColReq)
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, createStatus.GetErrorCode())
	globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
	// DescribeCollection task under test.
	task := &describeCollectionTask{
		Condition: NewTaskCondition(ctx),
		DescribeCollectionRequest: &milvuspb.DescribeCollectionRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_DescribeCollection,
				MsgID:     100,
				Timestamp: 100,
			},
			DbName:         dbName,
			CollectionName: collectionName,
		},
		ctx:       ctx,
		rootCoord: rc,
		result:    nil,
	}
	err = task.PreExecute(ctx)
	assert.NoError(t, err)
	err = task.Execute(ctx)
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, task.result.GetStatus().GetErrorCode())
	assert.Equal(t, shardsNum, task.result.ShardsNum)
	assert.Equal(t, collectionName, task.result.GetCollectionName())
	// Nil-safe getters: a missing schema fails the assertion instead of panicking.
	assert.Len(t, task.result.GetSchema().GetFields(), 2)
}
// TestDescribeCollectionTask_ShardsNum2 creates a collection without setting
// ShardsNum in the create request and checks that describeCollectionTask
// reports common.DefaultShardsNum for it.
func TestDescribeCollectionTask_ShardsNum2(t *testing.T) {
	rc := NewRootCoordMock()
	// The mock rootcoord is closed once, by this defer.
	defer rc.Close()
	qc := getQueryCoordClient()
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
	ctx := context.Background()
	mgr := newShardClientMgr()
	InitMetaCache(ctx, rc, qc, mgr)
	prefix := "TestDescribeCollectionTask"
	dbName := ""
	collectionName := prefix + funcutil.GenRandomStr()
	int64Field := "int64"
	floatVecField := "fvec"
	dim := 128
	schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
	marshaledSchema, err := proto.Marshal(schema)
	assert.NoError(t, err)
	// This is a CreateCollection request, so it carries the matching MsgType.
	// ShardsNum is intentionally left unset.
	createColReq := &milvuspb.CreateCollectionRequest{
		Base: &commonpb.MsgBase{
			MsgType:   commonpb.MsgType_CreateCollection,
			MsgID:     100,
			Timestamp: 100,
		},
		DbName:         dbName,
		CollectionName: collectionName,
		Schema:         marshaledSchema,
	}
	// Create the collection directly in the mock rootcoord and warm the meta cache.
	createStatus, err := rc.CreateCollection(ctx, createColReq)
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, createStatus.GetErrorCode())
	globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
	// DescribeCollection task under test.
	task := &describeCollectionTask{
		Condition: NewTaskCondition(ctx),
		DescribeCollectionRequest: &milvuspb.DescribeCollectionRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_DescribeCollection,
				MsgID:     100,
				Timestamp: 100,
			},
			DbName:         dbName,
			CollectionName: collectionName,
		},
		ctx:       ctx,
		rootCoord: rc,
		result:    nil,
	}
	err = task.PreExecute(ctx)
	assert.NoError(t, err)
	err = task.Execute(ctx)
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, task.result.GetStatus().GetErrorCode())
	assert.Equal(t, common.DefaultShardsNum, task.result.ShardsNum)
	assert.Equal(t, collectionName, task.result.GetCollectionName())
}
// TestCreatePartitionTask checks createPartitionTask's metadata accessors, that
// Execute fails when nothing has been created in the mock rootcoord, and that
// PreExecute rejects illegal collection and partition names.
func TestCreatePartitionTask(t *testing.T) {
	rc := NewRootCoordMock()
	defer rc.Close()
	ctx := context.Background()

	const namePrefix = "TestCreatePartitionTask"
	const illegalName = "#0xc0de"
	var (
		dbName         = ""
		collectionName = namePrefix + funcutil.GenRandomStr()
		partitionName  = namePrefix + funcutil.GenRandomStr()
	)

	req := &milvuspb.CreatePartitionRequest{
		Base: &commonpb.MsgBase{
			MsgType:   commonpb.MsgType_CreatePartition,
			MsgID:     100,
			Timestamp: 100,
		},
		DbName:         dbName,
		CollectionName: collectionName,
		PartitionName:  partitionName,
	}
	cpt := &createPartitionTask{
		Condition:              NewTaskCondition(ctx),
		CreatePartitionRequest: req,
		ctx:                    ctx,
		rootCoord:              rc,
		result:                 nil,
	}

	// Return values are deliberately ignored; only the accessors are checked.
	cpt.OnEnqueue()
	cpt.PreExecute(ctx)
	assert.Equal(t, commonpb.MsgType_CreatePartition, cpt.Type())
	assert.Equal(t, UniqueID(100), cpt.ID())
	assert.Equal(t, Timestamp(100), cpt.BeginTs())
	assert.Equal(t, Timestamp(100), cpt.EndTs())
	assert.Equal(t, paramtable.GetNodeID(), cpt.GetBase().GetSourceID())

	// Nothing exists in the mock rootcoord; Execute is expected to fail.
	assert.Error(t, cpt.Execute(ctx))

	// Illegal collection name.
	cpt.CollectionName = illegalName
	assert.Error(t, cpt.PreExecute(ctx))

	// Legal collection name, illegal partition name.
	cpt.CollectionName = collectionName
	cpt.PartitionName = illegalName
	assert.Error(t, cpt.PreExecute(ctx))
}
// TestDropPartitionTask exercises dropPartitionTask with a mocked QueryCoord
// and a MockCache installed as globalMetaCache. It covers the metadata
// accessors, Execute, PreExecute name validation, and PreExecute for several
// cache responses: a collection ID error, partition-not-found, and a generic
// partition error.
func TestDropPartitionTask(t *testing.T) {
	rc := NewRootCoordMock()
	defer rc.Close()
	ctx := context.Background()
	prefix := "TestDropPartitionTask"
	dbName := ""
	collectionName := prefix + funcutil.GenRandomStr()
	partitionName := prefix + funcutil.GenRandomStr()
	qc := getQueryCoordClient()
	qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{
		Status:       merr.Success(),
		PartitionIDs: []int64{},
	}, nil)
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
		Status: merr.Success(),
	}, nil)
	// Cache that resolves any collection/partition name to ID 1 and returns an
	// empty schema. NOTE: globalMetaCache is package-level state and is not
	// restored when this test (or its subtests) finish.
	mockCache := NewMockCache(t)
	mockCache.On("GetCollectionID",
		mock.Anything, // context.Context
		mock.AnythingOfType("string"),
		mock.AnythingOfType("string"),
	).Return(UniqueID(1), nil)
	mockCache.On("GetPartitionID",
		mock.Anything, // context.Context
		mock.AnythingOfType("string"),
		mock.AnythingOfType("string"),
		mock.AnythingOfType("string"),
	).Return(UniqueID(1), nil)
	mockCache.On("GetCollectionSchema",
		mock.Anything, // context.Context
		mock.AnythingOfType("string"),
		mock.AnythingOfType("string"),
		mock.AnythingOfType("string"),
	).Return(newSchemaInfo(&schemapb.CollectionSchema{}), nil)
	globalMetaCache = mockCache
	// DropPartition task under test.
	task := &dropPartitionTask{
		Condition: NewTaskCondition(ctx),
		DropPartitionRequest: &milvuspb.DropPartitionRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_DropPartition,
				MsgID:     100,
				Timestamp: 100,
			},
			DbName:         dbName,
			CollectionName: collectionName,
			PartitionName:  partitionName,
		},
		ctx:        ctx,
		rootCoord:  rc,
		queryCoord: qc,
		result:     nil,
	}
	// Results are intentionally ignored; only the accessors below are checked.
	task.OnEnqueue()
	task.PreExecute(ctx)
	assert.Equal(t, commonpb.MsgType_DropPartition, task.Type())
	assert.Equal(t, UniqueID(100), task.ID())
	assert.Equal(t, Timestamp(100), task.BeginTs())
	assert.Equal(t, Timestamp(100), task.EndTs())
	assert.Equal(t, paramtable.GetNodeID(), task.GetBase().GetSourceID())
	// Execute is expected to fail here.
	err := task.Execute(ctx)
	assert.Error(t, err)
	// Illegal collection name.
	task.CollectionName = "#0xc0de"
	err = task.PreExecute(ctx)
	assert.Error(t, err)
	// Legal collection name, illegal partition name.
	task.CollectionName = collectionName
	task.PartitionName = "#0xc0de"
	err = task.PreExecute(ctx)
	assert.Error(t, err)
	// The subtests below assign to the enclosing test's err and task.
	// A GetCollectionID failure from the cache makes PreExecute fail.
	t.Run("get collectionID error", func(t *testing.T) {
		mockCache := NewMockCache(t)
		mockCache.On("GetCollectionID",
			mock.Anything, // context.Context
			mock.AnythingOfType("string"),
			mock.AnythingOfType("string"),
		).Return(UniqueID(1), errors.New("error"))
		mockCache.On("GetCollectionSchema",
			mock.Anything, // context.Context
			mock.AnythingOfType("string"),
			mock.AnythingOfType("string"),
			mock.AnythingOfType("string"),
		).Return(newSchemaInfo(&schemapb.CollectionSchema{}), nil)
		globalMetaCache = mockCache
		task.PartitionName = "partition1"
		err = task.PreExecute(ctx)
		assert.Error(t, err)
	})
	// A partition-not-found error from the cache is tolerated: PreExecute succeeds.
	t.Run("partition not exist", func(t *testing.T) {
		task.PartitionName = "partition2"
		mockCache := NewMockCache(t)
		mockCache.On("GetPartitionID",
			mock.Anything, // context.Context
			mock.AnythingOfType("string"),
			mock.AnythingOfType("string"),
			mock.AnythingOfType("string"),
		).Return(UniqueID(0), merr.WrapErrPartitionNotFound(partitionName))
		mockCache.On("GetCollectionID",
			mock.Anything, // context.Context
			mock.AnythingOfType("string"),
			mock.AnythingOfType("string"),
		).Return(UniqueID(1), nil)
		mockCache.On("GetCollectionSchema",
			mock.Anything, // context.Context
			mock.AnythingOfType("string"),
			mock.AnythingOfType("string"),
			mock.AnythingOfType("string"),
		).Return(newSchemaInfo(&schemapb.CollectionSchema{}), nil)
		globalMetaCache = mockCache
		err = task.PreExecute(ctx)
		assert.NoError(t, err)
	})
	// Any other GetPartitionID error makes PreExecute fail.
	t.Run("get partition error", func(t *testing.T) {
		task.PartitionName = "partition3"
		mockCache := NewMockCache(t)
		mockCache.On("GetPartitionID",
			mock.Anything, // context.Context
			mock.AnythingOfType("string"),
			mock.AnythingOfType("string"),
			mock.AnythingOfType("string"),
		).Return(UniqueID(0), errors.New("error"))
		mockCache.On("GetCollectionID",
			mock.Anything, // context.Context
			mock.AnythingOfType("string"),
			mock.AnythingOfType("string"),
		).Return(UniqueID(1), nil)
		mockCache.On("GetCollectionSchema",
			mock.Anything, // context.Context
			mock.AnythingOfType("string"),
			mock.AnythingOfType("string"),
			mock.AnythingOfType("string"),
		).Return(newSchemaInfo(&schemapb.CollectionSchema{}), nil)
		globalMetaCache = mockCache
		err = task.PreExecute(ctx)
		assert.Error(t, err)
	})
}
// TestHasPartitionTask checks hasPartitionTask's metadata accessors, that
// Execute fails when nothing has been created in the mock rootcoord, and that
// PreExecute rejects illegal collection and partition names.
func TestHasPartitionTask(t *testing.T) {
	rc := NewRootCoordMock()
	defer rc.Close()
	ctx := context.Background()

	const namePrefix = "TestHasPartitionTask"
	const illegalName = "#0xc0de"
	var (
		dbName         = ""
		collectionName = namePrefix + funcutil.GenRandomStr()
		partitionName  = namePrefix + funcutil.GenRandomStr()
	)

	req := &milvuspb.HasPartitionRequest{
		Base: &commonpb.MsgBase{
			MsgType:   commonpb.MsgType_HasPartition,
			MsgID:     100,
			Timestamp: 100,
		},
		DbName:         dbName,
		CollectionName: collectionName,
		PartitionName:  partitionName,
	}
	hpt := &hasPartitionTask{
		Condition:           NewTaskCondition(ctx),
		HasPartitionRequest: req,
		ctx:                 ctx,
		rootCoord:           rc,
		result:              nil,
	}

	// Return values are deliberately ignored; only the accessors are checked.
	hpt.OnEnqueue()
	hpt.PreExecute(ctx)
	assert.Equal(t, commonpb.MsgType_HasPartition, hpt.Type())
	assert.Equal(t, UniqueID(100), hpt.ID())
	assert.Equal(t, Timestamp(100), hpt.BeginTs())
	assert.Equal(t, Timestamp(100), hpt.EndTs())
	assert.Equal(t, paramtable.GetNodeID(), hpt.GetBase().GetSourceID())

	// Nothing exists in the mock rootcoord; Execute is expected to fail.
	assert.Error(t, hpt.Execute(ctx))

	// Illegal collection name.
	hpt.CollectionName = illegalName
	assert.Error(t, hpt.PreExecute(ctx))

	// Legal collection name, illegal partition name.
	hpt.CollectionName = collectionName
	hpt.PartitionName = illegalName
	assert.Error(t, hpt.PreExecute(ctx))
}
// TestShowPartitionsTask checks showPartitionsTask's metadata accessors, that
// Execute fails when nothing has been created in the mock rootcoord (for both
// ShowType_All and ShowType_InMemory), and that PreExecute rejects an illegal
// collection name and, for in-memory listing, an illegal partition name.
func TestShowPartitionsTask(t *testing.T) {
	rc := NewRootCoordMock()
	defer rc.Close()
	ctx := context.Background()

	const namePrefix = "TestShowPartitionsTask"
	const illegalName = "#0xc0de"
	var (
		dbName         = ""
		collectionName = namePrefix + funcutil.GenRandomStr()
		partitionName  = namePrefix + funcutil.GenRandomStr()
	)

	req := &milvuspb.ShowPartitionsRequest{
		Base: &commonpb.MsgBase{
			MsgType:   commonpb.MsgType_ShowPartitions,
			MsgID:     100,
			Timestamp: 100,
		},
		DbName:         dbName,
		CollectionName: collectionName,
		PartitionNames: []string{partitionName},
		Type:           milvuspb.ShowType_All,
	}
	spt := &showPartitionsTask{
		Condition:             NewTaskCondition(ctx),
		ShowPartitionsRequest: req,
		ctx:                   ctx,
		rootCoord:             rc,
		result:                nil,
	}

	// Return values are deliberately ignored; only the accessors are checked.
	spt.OnEnqueue()
	spt.PreExecute(ctx)
	assert.Equal(t, commonpb.MsgType_ShowPartitions, spt.Type())
	assert.Equal(t, UniqueID(100), spt.ID())
	assert.Equal(t, Timestamp(100), spt.BeginTs())
	assert.Equal(t, Timestamp(100), spt.EndTs())
	assert.Equal(t, paramtable.GetNodeID(), spt.GetBase().GetSourceID())

	// Nothing exists in the mock rootcoord; Execute is expected to fail.
	assert.Error(t, spt.Execute(ctx))

	// Illegal collection name.
	spt.CollectionName = illegalName
	assert.Error(t, spt.PreExecute(ctx))

	// In-memory listing with an illegal partition name is rejected.
	// The request's Type is reached through the embedded request because the
	// task also has a Type() method.
	spt.CollectionName = collectionName
	spt.ShowPartitionsRequest.Type = milvuspb.ShowType_InMemory
	spt.PartitionNames = []string{illegalName}
	assert.Error(t, spt.PreExecute(ctx))

	// In-memory listing with a legal partition name; Execute still fails.
	spt.CollectionName = collectionName
	spt.PartitionNames = []string{partitionName}
	spt.ShowPartitionsRequest.Type = milvuspb.ShowType_InMemory
	assert.Error(t, spt.Execute(ctx))
}
// TestTask_Int64PrimaryKey exercises create collection, insert, and delete end
// to end for a collection whose primary key is the Int64 field. It is wired to
// a mock rootcoord, a mockDataCoord-backed segment ID assigner, a simple mock
// msgstream factory, and a channels time ticker.
func TestTask_Int64PrimaryKey(t *testing.T) {
	var err error
	rc := NewRootCoordMock()
	defer rc.Close()
	qc := getQueryCoordClient()
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
	ctx := context.Background()
	mgr := newShardClientMgr()
	err = InitMetaCache(ctx, rc, qc, mgr)
	assert.NoError(t, err)
	shardsNum := int32(2)
	prefix := "TestTask_all"
	dbName := ""
	collectionName := prefix + funcutil.GenRandomStr()
	partitionName := prefix + funcutil.GenRandomStr()
	// Field name -> data type for the collection schema; testInt64Field is used
	// as the primary key below.
	fieldName2Types := map[string]schemapb.DataType{
		testBoolField:     schemapb.DataType_Bool,
		testInt32Field:    schemapb.DataType_Int32,
		testInt64Field:    schemapb.DataType_Int64,
		testFloatField:    schemapb.DataType_Float,
		testDoubleField:   schemapb.DataType_Double,
		testFloatVecField: schemapb.DataType_FloatVector,
	}
	if enableMultipleVectorFields {
		fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
	}
	// Number of rows generated for the insert.
	nb := 10
	t.Run("create collection", func(t *testing.T) {
		schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false)
		marshaledSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		createColT := &createCollectionTask{
			Condition: NewTaskCondition(ctx),
			CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
				Base:           nil,
				DbName:         dbName,
				CollectionName: collectionName,
				Schema:         marshaledSchema,
				ShardsNum:      shardsNum,
			},
			ctx:       ctx,
			rootCoord: rc,
			result:    nil,
			schema:    nil,
		}
		assert.NoError(t, createColT.OnEnqueue())
		assert.NoError(t, createColT.PreExecute(ctx))
		assert.NoError(t, createColT.Execute(ctx))
		assert.NoError(t, createColT.PostExecute(ctx))
		// The CreatePartition result is deliberately discarded.
		_, _ = rc.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_CreatePartition,
				MsgID:     0,
				Timestamp: 0,
				SourceID:  paramtable.GetNodeID(),
			},
			DbName:         dbName,
			CollectionName: collectionName,
			PartitionName:  partitionName,
		})
	})
	collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
	assert.NoError(t, err)
	// DML channel plumbing shared by the insert and delete subtests.
	dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
	factory := newSimpleMockMsgStreamFactory()
	chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
	defer chMgr.removeAllDMLStream()
	_, err = chMgr.getOrCreateDmlStream(collectionID)
	assert.NoError(t, err)
	pchans, err := chMgr.getChannels(collectionID)
	assert.NoError(t, err)
	interval := time.Millisecond * 10
	tso := newMockTsoAllocator()
	ticker := newChannelsTimeTicker(ctx, interval, []string{}, newGetStatisticsFunc(pchans), tso)
	_ = ticker.start()
	defer ticker.close()
	// ID and segment allocators used by the DML tasks.
	idAllocator, err := allocator.NewIDAllocator(ctx, rc, paramtable.GetNodeID())
	assert.NoError(t, err)
	_ = idAllocator.Start()
	defer idAllocator.Close()
	segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
	assert.NoError(t, err)
	_ = segAllocator.Start()
	defer segAllocator.Close()
	// Column-based insert of nb generated rows covering every schema field.
	t.Run("insert", func(t *testing.T) {
		hash := testutils.GenerateHashKeys(nb)
		task := &insertTask{
			insertMsg: &BaseInsertTask{
				BaseMsg: msgstream.BaseMsg{
					HashValues: hash,
				},
				InsertRequest: &msgpb.InsertRequest{
					Base: &commonpb.MsgBase{
						MsgType:  commonpb.MsgType_Insert,
						MsgID:    0,
						SourceID: paramtable.GetNodeID(),
					},
					DbName:         dbName,
					CollectionName: collectionName,
					PartitionName:  partitionName,
					NumRows:        uint64(nb),
					Version:        msgpb.InsertDataVersion_ColumnBased,
				},
			},
			Condition: NewTaskCondition(ctx),
			ctx:       ctx,
			result: &milvuspb.MutationResult{
				Status:       merr.Success(),
				IDs:          nil,
				SuccIndex:    nil,
				ErrIndex:     nil,
				Acknowledged: false,
				InsertCnt:    0,
				DeleteCnt:    0,
				UpsertCnt:    0,
				Timestamp:    0,
			},
			idAllocator:   idAllocator,
			segIDAssigner: segAllocator,
			chMgr:         chMgr,
			chTicker:      ticker,
			vChannels:     nil,
			pChannels:     nil,
			schema:        nil,
		}
		for fieldName, dataType := range fieldName2Types {
			task.insertMsg.FieldsData = append(task.insertMsg.FieldsData, generateFieldData(dataType, fieldName, nb))
		}
		assert.NoError(t, task.OnEnqueue())
		assert.NoError(t, task.PreExecute(ctx))
		assert.NoError(t, task.Execute(ctx))
		assert.NoError(t, task.PostExecute(ctx))
	})
	// Delete by expression, with primaryKeys pre-filled to match the expression.
	t.Run("simple delete", func(t *testing.T) {
		task := &deleteTask{
			Condition: NewTaskCondition(ctx),
			req: &milvuspb.DeleteRequest{
				CollectionName: collectionName,
				PartitionName:  partitionName,
				Expr:           "int64 in [0, 1]",
			},
			idAllocator: idAllocator,
			ctx:         ctx,
			primaryKeys: &schemapb.IDs{
				IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{0, 1}}},
			},
			chMgr:        chMgr,
			chTicker:     ticker,
			collectionID: collectionID,
			vChannels:    []string{"test-ch"},
		}
		assert.NoError(t, task.OnEnqueue())
		assert.NotNil(t, task.TraceCtx())
		id := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
		task.SetID(id)
		assert.Equal(t, id, task.ID())
		assert.Equal(t, commonpb.MsgType_Delete, task.Type())
		ts := Timestamp(time.Now().UnixNano())
		task.SetTs(ts)
		assert.Equal(t, ts, task.BeginTs())
		assert.Equal(t, ts, task.EndTs())
		assert.NoError(t, task.PreExecute(ctx))
		assert.NoError(t, task.Execute(ctx))
		assert.NoError(t, task.PostExecute(ctx))
	})
}
// TestIndexType verifies that, with auto-index enabled, createCollectionTask's
// PreExecute rejects a schema whose first vector field carries the mmap key in
// its index params.
func TestIndexType(t *testing.T) {
	rc := NewRootCoordMock()
	defer rc.Close()
	ctx := context.Background()

	shardsNum := int32(2)
	dbName := ""
	collectionName := "TestTask_all" + funcutil.GenRandomStr()
	fieldTypes := map[string]schemapb.DataType{
		testBoolField:     schemapb.DataType_Bool,
		testInt32Field:    schemapb.DataType_Int32,
		testInt64Field:    schemapb.DataType_Int64,
		testFloatField:    schemapb.DataType_Float,
		testDoubleField:   schemapb.DataType_Double,
		testFloatVecField: schemapb.DataType_FloatVector,
	}

	t.Run("invalid type param", func(t *testing.T) {
		paramtable.Init()
		Params.Save(Params.AutoIndexConfig.Enable.Key, "true")
		defer Params.Reset(Params.AutoIndexConfig.Enable.Key)

		schema := constructCollectionSchemaByDataType(collectionName, fieldTypes, testInt64Field, false)
		// Attach the mmap flag to the index params of the first vector field only.
		for i := range schema.Fields {
			fs := schema.Fields[i]
			if !typeutil.IsVectorType(fs.GetDataType()) {
				continue
			}
			fs.IndexParams = append(fs.IndexParams, &commonpb.KeyValuePair{
				Key:   common.MmapEnabledKey,
				Value: "true",
			})
			break
		}

		raw, err := proto.Marshal(schema)
		assert.NoError(t, err)

		req := &milvuspb.CreateCollectionRequest{
			Base:           nil,
			DbName:         dbName,
			CollectionName: collectionName,
			Schema:         raw,
			ShardsNum:      shardsNum,
		}
		cct := &createCollectionTask{
			Condition:               NewTaskCondition(ctx),
			CreateCollectionRequest: req,
			ctx:                     ctx,
			rootCoord:               rc,
		}
		assert.NoError(t, cct.OnEnqueue())
		assert.Error(t, cct.PreExecute(ctx))
	})
}
// TestTask_VarCharPrimaryKey exercises create collection, insert, upsert, and
// delete end to end for a collection whose primary key is the VarChar field.
// It is wired to a mock rootcoord, a mockDataCoord-backed segment ID assigner,
// a simple mock msgstream factory, and a channels time ticker.
func TestTask_VarCharPrimaryKey(t *testing.T) {
	var err error
	rc := NewRootCoordMock()
	defer rc.Close()
	qc := getQueryCoordClient()
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
	ctx := context.Background()
	mgr := newShardClientMgr()
	err = InitMetaCache(ctx, rc, qc, mgr)
	assert.NoError(t, err)
	shardsNum := int32(2)
	prefix := "TestTask_all"
	dbName := ""
	collectionName := prefix + funcutil.GenRandomStr()
	partitionName := prefix + funcutil.GenRandomStr()
	// Field name -> data type for the collection schema; testVarCharField is
	// used as the primary key below.
	fieldName2Types := map[string]schemapb.DataType{
		testBoolField:     schemapb.DataType_Bool,
		testInt32Field:    schemapb.DataType_Int32,
		testInt64Field:    schemapb.DataType_Int64,
		testFloatField:    schemapb.DataType_Float,
		testDoubleField:   schemapb.DataType_Double,
		testVarCharField:  schemapb.DataType_VarChar,
		testFloatVecField: schemapb.DataType_FloatVector,
	}
	if enableMultipleVectorFields {
		fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
	}
	// Number of rows generated for each insert/upsert.
	nb := 10
	t.Run("create collection", func(t *testing.T) {
		schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testVarCharField, false)
		marshaledSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		createColT := &createCollectionTask{
			Condition: NewTaskCondition(ctx),
			CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
				Base:           nil,
				DbName:         dbName,
				CollectionName: collectionName,
				Schema:         marshaledSchema,
				ShardsNum:      shardsNum,
			},
			ctx:       ctx,
			rootCoord: rc,
			result:    nil,
			schema:    nil,
		}
		assert.NoError(t, createColT.OnEnqueue())
		assert.NoError(t, createColT.PreExecute(ctx))
		assert.NoError(t, createColT.Execute(ctx))
		assert.NoError(t, createColT.PostExecute(ctx))
		// The CreatePartition result is deliberately discarded.
		_, _ = rc.CreatePartition(ctx, &milvuspb.CreatePartitionRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_CreatePartition,
				MsgID:     0,
				Timestamp: 0,
				SourceID:  paramtable.GetNodeID(),
			},
			DbName:         dbName,
			CollectionName: collectionName,
			PartitionName:  partitionName,
		})
	})
	collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
	assert.NoError(t, err)
	// DML channel plumbing shared by the insert, upsert and delete subtests.
	dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
	factory := newSimpleMockMsgStreamFactory()
	chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
	defer chMgr.removeAllDMLStream()
	_, err = chMgr.getOrCreateDmlStream(collectionID)
	assert.NoError(t, err)
	pchans, err := chMgr.getChannels(collectionID)
	assert.NoError(t, err)
	interval := time.Millisecond * 10
	tso := newMockTsoAllocator()
	ticker := newChannelsTimeTicker(ctx, interval, []string{}, newGetStatisticsFunc(pchans), tso)
	_ = ticker.start()
	defer ticker.close()
	// ID and segment allocators used by the DML tasks.
	idAllocator, err := allocator.NewIDAllocator(ctx, rc, paramtable.GetNodeID())
	assert.NoError(t, err)
	_ = idAllocator.Start()
	defer idAllocator.Close()
	segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
	assert.NoError(t, err)
	segAllocator.Init()
	_ = segAllocator.Start()
	defer segAllocator.Close()
	// Column-based insert of nb generated rows covering every schema field.
	t.Run("insert", func(t *testing.T) {
		hash := testutils.GenerateHashKeys(nb)
		task := &insertTask{
			insertMsg: &BaseInsertTask{
				BaseMsg: msgstream.BaseMsg{
					HashValues: hash,
				},
				InsertRequest: &msgpb.InsertRequest{
					Base: &commonpb.MsgBase{
						MsgType:  commonpb.MsgType_Insert,
						MsgID:    0,
						SourceID: paramtable.GetNodeID(),
					},
					DbName:         dbName,
					CollectionName: collectionName,
					PartitionName:  partitionName,
					NumRows:        uint64(nb),
					Version:        msgpb.InsertDataVersion_ColumnBased,
				},
			},
			Condition: NewTaskCondition(ctx),
			ctx:       ctx,
			result: &milvuspb.MutationResult{
				Status:       merr.Success(),
				IDs:          nil,
				SuccIndex:    nil,
				ErrIndex:     nil,
				Acknowledged: false,
				InsertCnt:    0,
				DeleteCnt:    0,
				UpsertCnt:    0,
				Timestamp:    0,
			},
			idAllocator:   idAllocator,
			segIDAssigner: segAllocator,
			chMgr:         chMgr,
			chTicker:      ticker,
			vChannels:     nil,
			pChannels:     nil,
			schema:        nil,
		}
		for fieldName, dataType := range fieldName2Types {
			task.insertMsg.FieldsData = append(task.insertMsg.FieldsData, generateFieldData(dataType, fieldName, nb))
		}
		assert.NoError(t, task.OnEnqueue())
		assert.NoError(t, task.PreExecute(ctx))
		assert.NoError(t, task.Execute(ctx))
		assert.NoError(t, task.PostExecute(ctx))
	})
	// Upsert of nb generated rows; the field data is carried on the request.
	t.Run("upsert", func(t *testing.T) {
		hash := testutils.GenerateHashKeys(nb)
		task := &upsertTask{
			upsertMsg: &msgstream.UpsertMsg{
				InsertMsg: &BaseInsertTask{
					BaseMsg: msgstream.BaseMsg{
						HashValues: hash,
					},
					InsertRequest: &msgpb.InsertRequest{
						Base: &commonpb.MsgBase{
							MsgType:  commonpb.MsgType_Insert,
							MsgID:    0,
							SourceID: paramtable.GetNodeID(),
						},
						DbName:         dbName,
						CollectionName: collectionName,
						PartitionName:  partitionName,
						NumRows:        uint64(nb),
						Version:        msgpb.InsertDataVersion_ColumnBased,
					},
				},
				DeleteMsg: &msgstream.DeleteMsg{
					BaseMsg: msgstream.BaseMsg{
						HashValues: hash,
					},
					DeleteRequest: &msgpb.DeleteRequest{
						Base: &commonpb.MsgBase{
							MsgType:   commonpb.MsgType_Delete,
							MsgID:     0,
							Timestamp: 0,
							SourceID:  paramtable.GetNodeID(),
						},
						DbName:         dbName,
						CollectionName: collectionName,
						PartitionName:  partitionName,
					},
				},
			},
			Condition: NewTaskCondition(ctx),
			req: &milvuspb.UpsertRequest{
				Base: &commonpb.MsgBase{
					MsgType:  commonpb.MsgType_Insert,
					MsgID:    0,
					SourceID: paramtable.GetNodeID(),
				},
				DbName:         dbName,
				CollectionName: collectionName,
				PartitionName:  partitionName,
				HashKeys:       hash,
				NumRows:        uint32(nb),
			},
			ctx: ctx,
			result: &milvuspb.MutationResult{
				Status:       merr.Success(),
				IDs:          nil,
				SuccIndex:    nil,
				ErrIndex:     nil,
				Acknowledged: false,
				InsertCnt:    0,
				DeleteCnt:    0,
				UpsertCnt:    0,
				Timestamp:    0,
			},
			idAllocator:   idAllocator,
			segIDAssigner: segAllocator,
			chMgr:         chMgr,
			chTicker:      ticker,
			vChannels:     nil,
			pChannels:     nil,
			schema:        nil,
		}
		for fieldName, dataType := range fieldName2Types {
			task.req.FieldsData = append(task.req.FieldsData, generateFieldData(dataType, fieldName, nb))
		}
		assert.NoError(t, task.OnEnqueue())
		assert.NoError(t, task.PreExecute(ctx))
		assert.NoError(t, task.Execute(ctx))
		assert.NoError(t, task.PostExecute(ctx))
	})
	// Delete by expression, with primaryKeys pre-filled to match the expression.
	t.Run("simple delete", func(t *testing.T) {
		task := &deleteTask{
			Condition: NewTaskCondition(ctx),
			req: &milvuspb.DeleteRequest{
				CollectionName: collectionName,
				PartitionName:  partitionName,
				Expr:           "varChar in [\"milvus\", \"test\"]",
			},
			idAllocator: idAllocator,
			ctx:         ctx,
			chMgr:       chMgr,
			chTicker:    ticker,
			vChannels:   []string{"test-channel"},
			primaryKeys: &schemapb.IDs{
				IdField: &schemapb.IDs_StrId{StrId: &schemapb.StringArray{Data: []string{"milvus", "test"}}},
			},
			collectionID: collectionID,
		}
		assert.NoError(t, task.OnEnqueue())
		assert.NotNil(t, task.TraceCtx())
		id := UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt())
		task.SetID(id)
		assert.Equal(t, id, task.ID())
		assert.Equal(t, commonpb.MsgType_Delete, task.Type())
		ts := Timestamp(time.Now().UnixNano())
		task.SetTs(ts)
		assert.Equal(t, ts, task.BeginTs())
		assert.Equal(t, ts, task.EndTs())
		assert.NoError(t, task.PreExecute(ctx))
		assert.NoError(t, task.Execute(ctx))
		assert.NoError(t, task.PostExecute(ctx))
	})
}
func Test_createIndexTask_getIndexedFieldAndFunction(t *testing.T) {
collectionName := "test"
fieldName := "test"
cit := &createIndexTask{
req: &milvuspb.CreateIndexRequest{
CollectionName: collectionName,
FieldName: fieldName,
},
}
idField := &schemapb.FieldSchema{
FieldID: 100,
Name: "id",
IsPrimaryKey: false,
DataType: schemapb.DataType_FloatVector,
TypeParams: nil,
IndexParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "128",
},
},
AutoID: false,
}
vectorField := &schemapb.FieldSchema{
FieldID: 101,
Name: fieldName,
IsPrimaryKey: false,
DataType: schemapb.DataType_FloatVector,
TypeParams: nil,
IndexParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "128",
},
},
AutoID: false,
}
t.Run("normal", func(t *testing.T) {
cache := NewMockCache(t)
cache.On("GetCollectionSchema",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(newSchemaInfo(&schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
idField,
vectorField,
},
}), nil)
globalMetaCache = cache
err := cit.getIndexedFieldAndFunction(context.Background())
assert.NoError(t, err)
assert.Equal(t, fieldName, cit.fieldSchema.GetName())
})
t.Run("schema not found", func(t *testing.T) {
cache := NewMockCache(t)
cache.On("GetCollectionSchema",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(nil, errors.New("mock"))
globalMetaCache = cache
err := cit.getIndexedFieldAndFunction(context.Background())
assert.Error(t, err)
})
t.Run("field not found", func(t *testing.T) {
otherField := typeutil.Clone(vectorField)
otherField.Name = otherField.Name + "_other"
cache := NewMockCache(t)
cache.On("GetCollectionSchema",
mock.Anything, // context.Context
mock.AnythingOfType("string"),
mock.AnythingOfType("string"),
).Return(newSchemaInfo(&schemapb.CollectionSchema{
Fields: []*schemapb.FieldSchema{
idField,
otherField,
},
}), nil)
globalMetaCache = cache
err := cit.getIndexedFieldAndFunction(context.Background())
assert.Error(t, err)
})
}
func Test_fillDimension(t *testing.T) {
	// vectorField128 builds a float-vector field whose index params declare dim=128.
	vectorField128 := func() *schemapb.FieldSchema {
		return &schemapb.FieldSchema{
			DataType: schemapb.DataType_FloatVector,
			IndexParams: []*commonpb.KeyValuePair{
				{Key: common.DimKey, Value: "128"},
			},
		}
	}

	cases := []struct {
		name    string
		field   *schemapb.FieldSchema
		params  map[string]string
		wantErr bool
		wantDim string // when non-empty, params[common.DimKey] must equal it afterwards
	}{
		{
			name:  "scalar",
			field: &schemapb.FieldSchema{DataType: schemapb.DataType_Int64},
		},
		{
			name:    "no dim in schema",
			field:   &schemapb.FieldSchema{DataType: schemapb.DataType_FloatVector},
			wantErr: true,
		},
		{
			name:    "dimension mismatch",
			field:   vectorField128(),
			params:  map[string]string{common.DimKey: "8"},
			wantErr: true,
		},
		{
			name:    "normal",
			field:   vectorField128(),
			params:  map[string]string{},
			wantDim: "128",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := fillDimension(tc.field, tc.params)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			assert.NoError(t, err)
			if tc.wantDim != "" {
				assert.Equal(t, tc.wantDim, tc.params[common.DimKey])
			}
		})
	}
}
func Test_checkTrain(t *testing.T) {
	// vectorField128 builds a float-vector field whose index params declare dim=128.
	vectorField128 := func() *schemapb.FieldSchema {
		return &schemapb.FieldSchema{
			DataType: schemapb.DataType_FloatVector,
			IndexParams: []*commonpb.KeyValuePair{
				{Key: common.DimKey, Value: "128"},
			},
		}
	}

	cases := []struct {
		name    string
		field   *schemapb.FieldSchema
		params  map[string]string
		wantErr bool
	}{
		{
			name:  "normal",
			field: vectorField128(),
			params: map[string]string{
				common.IndexTypeKey:  "IVF_FLAT",
				"nlist":              "1024",
				common.MetricTypeKey: "L2",
			},
		},
		{
			name:    "scalar",
			field:   &schemapb.FieldSchema{DataType: schemapb.DataType_Int64},
			params:  map[string]string{common.IndexTypeKey: "scalar"},
			wantErr: true,
		},
		{
			name:  "dimension mismatch",
			field: vectorField128(),
			params: map[string]string{
				common.IndexTypeKey:  "IVF_FLAT",
				"nlist":              "1024",
				common.MetricTypeKey: "L2",
				common.DimKey:        "8",
			},
			wantErr: true,
		},
		{
			// IVF_FLAT without "nlist".
			name:  "invalid params",
			field: vectorField128(),
			params: map[string]string{
				common.IndexTypeKey:  "IVF_FLAT",
				common.MetricTypeKey: "L2",
			},
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			err := checkTrain(tc.field, tc.params)
			if tc.wantErr {
				assert.Error(t, err)
			} else {
				assert.NoError(t, err)
			}
		})
	}
}
// Test_createIndexTask_PreExecute exercises createIndexTask.PreExecute with a
// valid request, a collection-lookup failure, and several invalid index names.
//
// Every subtest installs its own mocked globalMetaCache bound to that subtest's
// *testing.T. No subtest depends on a mock left behind by an earlier, already
// finished subtest, so each one can also be run in isolation via `go test -run`.
func Test_createIndexTask_PreExecute(t *testing.T) {
	collectionName := "test"
	fieldName := "test"
	cit := &createIndexTask{
		req: &milvuspb.CreateIndexRequest{
			Base: &commonpb.MsgBase{
				MsgType: commonpb.MsgType_CreateIndex,
			},
			CollectionName: collectionName,
			FieldName:      fieldName,
		},
	}

	// useCollectionIDCache installs a cache mock that resolves any collection
	// name to ID 100. The invalid-index-name subtests only need this lookup.
	// NOTE(review): this presumes PreExecute resolves the collection ID before
	// validating the index name, as the original 255-length subtest implied.
	useCollectionIDCache := func(t *testing.T) {
		cache := NewMockCache(t)
		cache.On("GetCollectionID",
			mock.Anything, // context.Context
			mock.AnythingOfType("string"),
			mock.AnythingOfType("string"),
		).Return(UniqueID(100), nil)
		globalMetaCache = cache
	}

	t.Run("normal", func(t *testing.T) {
		cache := NewMockCache(t)
		cache.On("GetCollectionID",
			mock.Anything, // context.Context
			mock.AnythingOfType("string"),
			mock.AnythingOfType("string"),
		).Return(UniqueID(100), nil)
		cache.On("GetCollectionSchema",
			mock.Anything, // context.Context
			mock.AnythingOfType("string"),
			mock.AnythingOfType("string"),
		).Return(newSchemaInfo(&schemapb.CollectionSchema{
			Fields: []*schemapb.FieldSchema{
				{
					FieldID:      100,
					Name:         fieldName,
					IsPrimaryKey: false,
					DataType:     schemapb.DataType_FloatVector,
					TypeParams:   nil,
					IndexParams: []*commonpb.KeyValuePair{
						{
							Key:   "dim",
							Value: "128",
						},
					},
					AutoID: false,
				},
			},
		}), nil)
		globalMetaCache = cache
		cit.req.ExtraParams = []*commonpb.KeyValuePair{
			{
				Key:   common.IndexTypeKey,
				Value: "IVF_FLAT",
			},
			{
				Key:   "nlist",
				Value: "1024",
			},
			{
				Key:   common.MetricTypeKey,
				Value: "L2",
			},
		}
		assert.NoError(t, cit.PreExecute(context.Background()))
	})

	t.Run("collection not found", func(t *testing.T) {
		cache := NewMockCache(t)
		cache.On("GetCollectionID",
			mock.Anything, // context.Context
			mock.AnythingOfType("string"),
			mock.AnythingOfType("string"),
		).Return(UniqueID(0), errors.New("mock"))
		globalMetaCache = cache
		assert.Error(t, cit.PreExecute(context.Background()))
	})

	t.Run("index name length exceed 255", func(t *testing.T) {
		useCollectionIDCache(t)
		// Assign the name rather than appending to it, so the length is exactly
		// 256 no matter what earlier subtests left in the request.
		name := make([]byte, 256)
		for i := range name {
			name[i] = 'a'
		}
		cit.req.IndexName = string(name)
		err := cit.PreExecute(context.Background())
		assert.Error(t, err)
	})

	t.Run("index name start with number", func(t *testing.T) {
		useCollectionIDCache(t)
		cit.req.IndexName = "12a"
		err := cit.PreExecute(context.Background())
		assert.Error(t, err)
	})

	t.Run("index name include special characters", func(t *testing.T) {
		useCollectionIDCache(t)
		cit.req.IndexName = "ac#1"
		err := cit.PreExecute(context.Background())
		assert.Error(t, err)
	})
}
func Test_dropCollectionTask_PreExecute(t *testing.T) {
	ctx := context.Background()
	dct := &dropCollectionTask{DropCollectionRequest: &milvuspb.DropCollectionRequest{
		Base:           &commonpb.MsgBase{},
		CollectionName: "0xffff", // invalid
	}}

	// An invalid collection name must be rejected ...
	assert.Error(t, dct.PreExecute(ctx))

	// ... while a well-formed one passes validation.
	dct.DropCollectionRequest.CollectionName = "valid"
	assert.NoError(t, dct.PreExecute(ctx))
}
func Test_dropCollectionTask_Execute(t *testing.T) {
	// The mocked RootCoord fails for "c1" (generic error) and "c2" (collection
	// not found) and succeeds for every other collection name.
	mockRC := mocks.NewMockRootCoordClient(t)
	mockRC.On("DropCollection",
		mock.Anything, // context.Context
		mock.Anything, // *milvuspb.DropCollectionRequest
		mock.Anything,
	).Return(&commonpb.Status{}, func(ctx context.Context, request *milvuspb.DropCollectionRequest, opts ...grpc.CallOption) error {
		switch request.GetCollectionName() {
		case "c1":
			return errors.New("error mock DropCollection")
		case "c2":
			return merr.WrapErrCollectionNotFound("mock")
		default:
			return nil
		}
	})

	ctx := context.Background()
	dct := &dropCollectionTask{rootCoord: mockRC, DropCollectionRequest: &milvuspb.DropCollectionRequest{CollectionName: "normal"}}

	// The same task is executed repeatedly, switching only the collection name.
	steps := []struct {
		collection string
		wantErr    bool
	}{
		{collection: "normal", wantErr: false},
		{collection: "c1", wantErr: true},
		{collection: "c2", wantErr: true},
	}
	for _, step := range steps {
		dct.DropCollectionRequest.CollectionName = step.collection
		if step.wantErr {
			assert.Error(t, dct.Execute(ctx))
		} else {
			assert.NoError(t, dct.Execute(ctx))
		}
	}
	assert.Equal(t, commonpb.ErrorCode_Success, dct.result.GetErrorCode())
}
// Test_dropCollectionTask_PostExecute checks that PostExecute succeeds on a
// zero-valued task.
func Test_dropCollectionTask_PostExecute(t *testing.T) {
	assert.NoError(t, new(dropCollectionTask).PostExecute(context.Background()))
}
func Test_loadCollectionTask_Execute(t *testing.T) {
rc := newMockRootCoord()
dc := NewDataCoordMock()
qc := getQueryCoordClient()
qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{
Status: merr.Success(),
PartitionIDs: []int64{},
}, nil)
qc.EXPECT().ShowCollections(mock.Anything, mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
Status: merr.Success(),
}, nil)
dbName := funcutil.GenRandomStr()
collectionName := funcutil.GenRandomStr()
collectionID := UniqueID(1)
// fieldName := funcutil.GenRandomStr()
indexName := funcutil.GenRandomStr()
ctx := context.Background()
indexID := int64(1000)
shardMgr := newShardClientMgr()
// failed to get collection id.
_ = InitMetaCache(ctx, rc, qc, shardMgr)
rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) {
return &milvuspb.DescribeCollectionResponse{
Status: merr.Success(),
Schema: newTestSchema(),
CollectionID: collectionID,
CollectionName: request.CollectionName,
}, nil
}
lct := &loadCollectionTask{
LoadCollectionRequest: &milvuspb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
MsgID: 1,
Timestamp: 1,
SourceID: 1,
TargetID: 1,
},
DbName: dbName,
CollectionName: collectionName,
ReplicaNumber: 1,
},
ctx: ctx,
queryCoord: qc,
datacoord: dc,
result: nil,
collectionID: 0,
}
t.Run("indexcoord describe index error", func(t *testing.T) {
err := lct.Execute(ctx)
assert.Error(t, err)
})
t.Run("indexcoord describe index not success", func(t *testing.T) {
dc.DescribeIndexFunc = func(ctx context.Context, request *indexpb.DescribeIndexRequest, opts ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error) {
return &indexpb.DescribeIndexResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "fail reason",
},
}, nil
}
err := lct.Execute(ctx)
assert.Error(t, err)
})
t.Run("no vector index", func(t *testing.T) {
dc.DescribeIndexFunc = func(ctx context.Context, request *indexpb.DescribeIndexRequest, opts ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error) {
return &indexpb.DescribeIndexResponse{
Status: merr.Success(),
IndexInfos: []*indexpb.IndexInfo{
{
CollectionID: collectionID,
FieldID: 100,
IndexName: indexName,
IndexID: indexID,
TypeParams: nil,
IndexParams: nil,
IndexedRows: 1025,
TotalRows: 1025,
State: commonpb.IndexState_Finished,
IndexStateFailReason: "",
IsAutoIndex: false,
UserIndexParams: nil,
},
},
}, nil
}
err := lct.Execute(ctx)
assert.Error(t, err)
})
t.Run("not all vector fields with index", func(t *testing.T) {
vecFields := make([]*schemapb.FieldSchema, 0)
for _, field := range newTestSchema().GetFields() {
if typeutil.IsVectorType(field.GetDataType()) {
vecFields = append(vecFields, field)
}
}
assert.GreaterOrEqual(t, len(vecFields), 2)
dc.DescribeIndexFunc = func(ctx context.Context, request *indexpb.DescribeIndexRequest, opts ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error) {
return &indexpb.DescribeIndexResponse{
Status: merr.Success(),
IndexInfos: []*indexpb.IndexInfo{
{
CollectionID: collectionID,
FieldID: vecFields[0].FieldID,
IndexName: indexName,
IndexID: indexID,
TypeParams: nil,
IndexParams: nil,
IndexedRows: 1025,
TotalRows: 1025,
State: commonpb.IndexState_Finished,
IndexStateFailReason: "",
IsAutoIndex: false,
UserIndexParams: nil,
},
},
}, nil
}
err := lct.Execute(ctx)
assert.Error(t, err)
})
}
// Test_loadPartitionTask_Execute checks that loadPartitionsTask.Execute fails
// when DataCoord's DescribeIndex errors, returns a non-success status, or only
// reports an index on a non-vector field.
func Test_loadPartitionTask_Execute(t *testing.T) {
	rc := newMockRootCoord()
	dc := NewDataCoordMock()
	qc := getQueryCoordClient()
	// Also match the variadic grpc.CallOption argument, consistent with
	// Test_loadCollectionTask_Execute.
	qc.EXPECT().ShowPartitions(mock.Anything, mock.Anything, mock.Anything).Return(&querypb.ShowPartitionsResponse{
		Status:       merr.Success(),
		PartitionIDs: []int64{},
	}, nil)
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
		Status: merr.Success(),
	}, nil)
	dbName := funcutil.GenRandomStr()
	collectionName := funcutil.GenRandomStr()
	collectionID := UniqueID(1)
	indexName := funcutil.GenRandomStr()
	ctx := context.Background()
	indexID := int64(1000)
	shardMgr := newShardClientMgr()
	// Initialise the global meta cache; its error is deliberately ignored.
	_ = InitMetaCache(ctx, rc, qc, shardMgr)
	rc.DescribeCollectionFunc = func(ctx context.Context, request *milvuspb.DescribeCollectionRequest, opts ...grpc.CallOption) (*milvuspb.DescribeCollectionResponse, error) {
		return &milvuspb.DescribeCollectionResponse{
			Status:         merr.Success(),
			Schema:         newTestSchema(),
			CollectionID:   collectionID,
			CollectionName: request.CollectionName,
		}, nil
	}
	lpt := &loadPartitionsTask{
		LoadPartitionsRequest: &milvuspb.LoadPartitionsRequest{
			Base: &commonpb.MsgBase{
				// This is a load-partitions request, so label it as one.
				MsgType:   commonpb.MsgType_LoadPartitions,
				MsgID:     1,
				Timestamp: 1,
				SourceID:  1,
				TargetID:  1,
			},
			DbName:         dbName,
			CollectionName: collectionName,
			ReplicaNumber:  1,
		},
		ctx:          ctx,
		queryCoord:   qc,
		datacoord:    dc,
		result:       nil,
		collectionID: 0,
	}
	t.Run("indexcoord describe index error", func(t *testing.T) {
		// DescribeIndexFunc is still unset here; Execute is expected to fail.
		err := lpt.Execute(ctx)
		assert.Error(t, err)
	})
	t.Run("indexcoord describe index not success", func(t *testing.T) {
		dc.DescribeIndexFunc = func(ctx context.Context, request *indexpb.DescribeIndexRequest, opts ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error) {
			return &indexpb.DescribeIndexResponse{
				Status: &commonpb.Status{
					ErrorCode: commonpb.ErrorCode_UnexpectedError,
					Reason:    "fail reason",
				},
			}, nil
		}
		err := lpt.Execute(ctx)
		assert.Error(t, err)
	})
	t.Run("no vector index", func(t *testing.T) {
		dc.DescribeIndexFunc = func(ctx context.Context, request *indexpb.DescribeIndexRequest, opts ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error) {
			return &indexpb.DescribeIndexResponse{
				Status: merr.Success(),
				IndexInfos: []*indexpb.IndexInfo{
					{
						CollectionID:         collectionID,
						FieldID:              100,
						IndexName:            indexName,
						IndexID:              indexID,
						TypeParams:           nil,
						IndexParams:          nil,
						IndexedRows:          1025,
						TotalRows:            1025,
						State:                commonpb.IndexState_Finished,
						IndexStateFailReason: "",
						IsAutoIndex:          false,
						UserIndexParams:      nil,
					},
				},
			}, nil
		}
		err := lpt.Execute(ctx)
		assert.Error(t, err)
	})
}
// TestCreateResourceGroupTask drives a CreateResourceGroupTask through
// OnEnqueue, PreExecute and Execute against a mocked QueryCoord. It checks the
// message header the task exposes and the success status it records.
func TestCreateResourceGroupTask(t *testing.T) {
	ctx := context.Background()

	rootCoord := NewRootCoordMock()
	defer rootCoord.Close()

	qc := getQueryCoordClient()
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
	qc.EXPECT().CreateResourceGroup(mock.Anything, mock.Anything, mock.Anything).Return(merr.Success(), nil)

	_ = InitMetaCache(ctx, rootCoord, qc, newShardClientMgr())

	task := &CreateResourceGroupTask{
		CreateResourceGroupRequest: &milvuspb.CreateResourceGroupRequest{
			Base: &commonpb.MsgBase{
				MsgID:     1,
				Timestamp: 2,
				TargetID:  3,
			},
			ResourceGroup: "rg",
		},
		ctx:        ctx,
		queryCoord: qc,
	}
	_ = task.OnEnqueue()
	_ = task.PreExecute(ctx)

	// Header fields: type and source come from the task, the rest from the request.
	assert.Equal(t, commonpb.MsgType_CreateResourceGroup, task.Type())
	assert.Equal(t, UniqueID(1), task.ID())
	assert.Equal(t, Timestamp(2), task.BeginTs())
	assert.Equal(t, Timestamp(2), task.EndTs())
	assert.Equal(t, paramtable.GetNodeID(), task.Base.GetSourceID())
	assert.Equal(t, UniqueID(3), task.Base.GetTargetID())

	assert.NoError(t, task.Execute(ctx))
	assert.Equal(t, commonpb.ErrorCode_Success, task.result.ErrorCode)
}
// TestDropResourceGroupTask drives a DropResourceGroupTask through OnEnqueue,
// PreExecute and Execute against a mocked QueryCoord. It checks the message
// header the task exposes and the success status it records.
func TestDropResourceGroupTask(t *testing.T) {
	ctx := context.Background()

	rootCoord := NewRootCoordMock()
	defer rootCoord.Close()

	qc := getQueryCoordClient()
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
	qc.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).Return(merr.Success(), nil)

	_ = InitMetaCache(ctx, rootCoord, qc, newShardClientMgr())

	task := &DropResourceGroupTask{
		DropResourceGroupRequest: &milvuspb.DropResourceGroupRequest{
			Base: &commonpb.MsgBase{
				MsgID:     1,
				Timestamp: 2,
				TargetID:  3,
			},
			ResourceGroup: "rg",
		},
		ctx:        ctx,
		queryCoord: qc,
	}
	_ = task.OnEnqueue()
	_ = task.PreExecute(ctx)

	// Header fields: type and source come from the task, the rest from the request.
	assert.Equal(t, commonpb.MsgType_DropResourceGroup, task.Type())
	assert.Equal(t, UniqueID(1), task.ID())
	assert.Equal(t, Timestamp(2), task.BeginTs())
	assert.Equal(t, Timestamp(2), task.EndTs())
	assert.Equal(t, paramtable.GetNodeID(), task.Base.GetSourceID())
	assert.Equal(t, UniqueID(3), task.Base.GetTargetID())

	assert.NoError(t, task.Execute(ctx))
	assert.Equal(t, commonpb.ErrorCode_Success, task.result.ErrorCode)
}
// TestTransferNodeTask drives a TransferNodeTask (one node from rg1 to rg2)
// through OnEnqueue, PreExecute and Execute against a mocked QueryCoord. It
// checks the message header and the success status the task records.
func TestTransferNodeTask(t *testing.T) {
	ctx := context.Background()

	rootCoord := NewRootCoordMock()
	defer rootCoord.Close()

	qc := getQueryCoordClient()
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
	qc.EXPECT().TransferNode(mock.Anything, mock.Anything).Return(merr.Success(), nil)

	_ = InitMetaCache(ctx, rootCoord, qc, newShardClientMgr())

	task := &TransferNodeTask{
		TransferNodeRequest: &milvuspb.TransferNodeRequest{
			Base: &commonpb.MsgBase{
				MsgID:     1,
				Timestamp: 2,
				TargetID:  3,
			},
			SourceResourceGroup: "rg1",
			TargetResourceGroup: "rg2",
			NumNode:             1,
		},
		ctx:        ctx,
		queryCoord: qc,
	}
	_ = task.OnEnqueue()
	_ = task.PreExecute(ctx)

	// Header fields: type and source come from the task, the rest from the request.
	assert.Equal(t, commonpb.MsgType_TransferNode, task.Type())
	assert.Equal(t, UniqueID(1), task.ID())
	assert.Equal(t, Timestamp(2), task.BeginTs())
	assert.Equal(t, Timestamp(2), task.EndTs())
	assert.Equal(t, paramtable.GetNodeID(), task.Base.GetSourceID())
	assert.Equal(t, UniqueID(3), task.Base.GetTargetID())

	assert.NoError(t, task.Execute(ctx))
	assert.Equal(t, commonpb.ErrorCode_Success, task.result.ErrorCode)
}
// TestTransferReplicaTask drives a TransferReplicaTask for "collection1"
// through OnEnqueue, PreExecute and Execute against a mocked QueryCoord. It
// checks the message header and the success status the task records.
func TestTransferReplicaTask(t *testing.T) {
	ctx := context.Background()

	rootCoord := &MockRootCoordClientInterface{}
	qc := getQueryCoordClient()
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
	qc.EXPECT().TransferReplica(mock.Anything, mock.Anything).Return(merr.Success(), nil)

	_ = InitMetaCache(ctx, rootCoord, qc, newShardClientMgr())
	// make it avoid remote call on rc
	_, _ = globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), "collection1")

	task := &TransferReplicaTask{
		TransferReplicaRequest: &milvuspb.TransferReplicaRequest{
			Base: &commonpb.MsgBase{
				MsgID:     1,
				Timestamp: 2,
				TargetID:  3,
			},
			CollectionName:      "collection1",
			SourceResourceGroup: "rg1",
			TargetResourceGroup: "rg2",
			NumReplica:          1,
		},
		ctx:        ctx,
		queryCoord: qc,
	}
	_ = task.OnEnqueue()
	_ = task.PreExecute(ctx)

	// Header fields: type and source come from the task, the rest from the request.
	assert.Equal(t, commonpb.MsgType_TransferReplica, task.Type())
	assert.Equal(t, UniqueID(1), task.ID())
	assert.Equal(t, Timestamp(2), task.BeginTs())
	assert.Equal(t, Timestamp(2), task.EndTs())
	assert.Equal(t, paramtable.GetNodeID(), task.Base.GetSourceID())
	assert.Equal(t, UniqueID(3), task.Base.GetTargetID())

	assert.NoError(t, task.Execute(ctx))
	assert.Equal(t, commonpb.ErrorCode_Success, task.result.ErrorCode)
}
// TestListResourceGroupsTask drives a ListResourceGroupsTask against a mocked
// QueryCoord that reports the default group plus "rg". It checks the message
// header, the result status, and that both group names are returned.
func TestListResourceGroupsTask(t *testing.T) {
	ctx := context.Background()

	rootCoord := &MockRootCoordClientInterface{}
	qc := getQueryCoordClient()
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
	qc.EXPECT().ListResourceGroups(mock.Anything, mock.Anything).Return(&milvuspb.ListResourceGroupsResponse{
		Status:         merr.Success(),
		ResourceGroups: []string{meta.DefaultResourceGroupName, "rg"},
	}, nil)

	_ = InitMetaCache(ctx, rootCoord, qc, newShardClientMgr())

	task := &ListResourceGroupsTask{
		ListResourceGroupsRequest: &milvuspb.ListResourceGroupsRequest{
			Base: &commonpb.MsgBase{
				MsgID:     1,
				Timestamp: 2,
				TargetID:  3,
			},
		},
		ctx:        ctx,
		queryCoord: qc,
	}
	_ = task.OnEnqueue()
	_ = task.PreExecute(ctx)

	// Header fields: type and source come from the task, the rest from the request.
	assert.Equal(t, commonpb.MsgType_ListResourceGroups, task.Type())
	assert.Equal(t, UniqueID(1), task.ID())
	assert.Equal(t, Timestamp(2), task.BeginTs())
	assert.Equal(t, Timestamp(2), task.EndTs())
	assert.Equal(t, paramtable.GetNodeID(), task.Base.GetSourceID())
	assert.Equal(t, UniqueID(3), task.Base.GetTargetID())

	assert.NoError(t, task.Execute(ctx))
	assert.Equal(t, commonpb.ErrorCode_Success, task.result.GetStatus().GetErrorCode())

	for _, want := range []string{meta.DefaultResourceGroupName, "rg"} {
		assert.Contains(t, task.result.GetResourceGroups(), want)
	}
}
// TestDescribeResourceGroupTask drives a DescribeResourceGroupTask against a
// mocked QueryCoord whose outgoing/incoming node maps are keyed by collection
// IDs 1 and 2. It checks the message header and the success status. It also
// checks that the task re-keys both maps by collection name ("collection1" and
// "collection2") while keeping the node counts.
func TestDescribeResourceGroupTask(t *testing.T) {
	rc := &MockRootCoordClientInterface{}
	qc := getQueryCoordClient()
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
	qc.EXPECT().DescribeResourceGroup(mock.Anything, mock.Anything).Return(&querypb.DescribeResourceGroupResponse{
		Status: merr.Success(),
		ResourceGroup: &querypb.ResourceGroupInfo{
			Name:             "rg",
			Capacity:         2,
			NumAvailableNode: 1,
			NumOutgoingNode:  map[int64]int32{1: 1},
			NumIncomingNode:  map[int64]int32{2: 2},
		},
	}, nil)
	ctx := context.Background()
	mgr := newShardClientMgr()
	_ = InitMetaCache(ctx, rc, qc, mgr)
	// make it avoid remote call on rc
	_, _ = globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), "collection1")
	_, _ = globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), "collection2")
	req := &milvuspb.DescribeResourceGroupRequest{
		Base: &commonpb.MsgBase{
			MsgID:     1,
			Timestamp: 2,
			TargetID:  3,
		},
		ResourceGroup: "rg",
	}
	task := &DescribeResourceGroupTask{
		DescribeResourceGroupRequest: req,
		ctx:                          ctx,
		queryCoord:                   qc,
	}
	_ = task.OnEnqueue()
	_ = task.PreExecute(ctx)
	assert.Equal(t, commonpb.MsgType_DescribeResourceGroup, task.Type())
	assert.Equal(t, UniqueID(1), task.ID())
	assert.Equal(t, Timestamp(2), task.BeginTs())
	assert.Equal(t, Timestamp(2), task.EndTs())
	assert.Equal(t, paramtable.GetNodeID(), task.Base.GetSourceID())
	assert.Equal(t, UniqueID(3), task.Base.GetTargetID())
	err := task.Execute(ctx)
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, task.result.GetStatus().GetErrorCode())
	groupInfo := task.result.GetResourceGroup()
	outgoingNodeNum := groupInfo.GetNumOutgoingNode()
	incomingNodeNum := groupInfo.GetNumIncomingNode()
	// Map values are int32, so NotNil would pass even for a missing key. Compare
	// against the counts the mock reported for collection IDs 1 and 2 instead.
	assert.Equal(t, int32(1), outgoingNodeNum["collection1"])
	assert.Equal(t, int32(2), incomingNodeNum["collection2"])
}
// TestDescribeResourceGroupTaskFailed covers two cases. First, QueryCoord
// answers DescribeResourceGroup with a failure status; Execute itself succeeds
// but records that status. Second, the mock is swapped to a success response
// whose node maps refer to collection IDs 3 and 4, and both maps in the
// result are expected to be empty.
func TestDescribeResourceGroupTaskFailed(t *testing.T) {
	ctx := context.Background()

	rootCoord := &MockRootCoordClientInterface{}
	qc := getQueryCoordClient()
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
	qc.EXPECT().DescribeResourceGroup(mock.Anything, mock.Anything).Return(&querypb.DescribeResourceGroupResponse{
		Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError},
	}, nil)

	_ = InitMetaCache(ctx, rootCoord, qc, newShardClientMgr())
	// make it avoid remote call on rc
	for _, name := range []string{"collection1", "collection2"} {
		_, _ = globalMetaCache.GetCollectionSchema(context.Background(), GetCurDBNameFromContextOrDefault(ctx), name)
	}

	task := &DescribeResourceGroupTask{
		DescribeResourceGroupRequest: &milvuspb.DescribeResourceGroupRequest{
			Base: &commonpb.MsgBase{
				MsgID:     1,
				Timestamp: 2,
				TargetID:  3,
			},
			ResourceGroup: "rgggg",
		},
		ctx:        ctx,
		queryCoord: qc,
	}
	_ = task.OnEnqueue()
	_ = task.PreExecute(ctx)

	// Header fields: type and source come from the task, the rest from the request.
	assert.Equal(t, commonpb.MsgType_DescribeResourceGroup, task.Type())
	assert.Equal(t, UniqueID(1), task.ID())
	assert.Equal(t, Timestamp(2), task.BeginTs())
	assert.Equal(t, Timestamp(2), task.EndTs())
	assert.Equal(t, paramtable.GetNodeID(), task.Base.GetSourceID())
	assert.Equal(t, UniqueID(3), task.Base.GetTargetID())

	assert.NoError(t, task.Execute(ctx))
	assert.Equal(t, commonpb.ErrorCode_UnexpectedError, task.result.GetStatus().GetErrorCode())

	// Replace the failing expectation with a successful one that references
	// collection IDs 3 and 4.
	qc.ExpectedCalls = nil
	qc.EXPECT().DescribeResourceGroup(mock.Anything, mock.Anything).Return(&querypb.DescribeResourceGroupResponse{
		Status: merr.Success(),
		ResourceGroup: &querypb.ResourceGroupInfo{
			Name:             "rg",
			Capacity:         2,
			NumAvailableNode: 1,
			NumOutgoingNode:  map[int64]int32{3: 1},
			NumIncomingNode:  map[int64]int32{4: 2},
		},
	}, nil)

	assert.NoError(t, task.Execute(ctx))
	assert.Len(t, task.result.ResourceGroup.NumOutgoingNode, 0)
	assert.Len(t, task.result.ResourceGroup.NumIncomingNode, 0)
}
// TestCreateCollectionTaskWithPartitionKey covers partition-key handling in
// collection creation.
//
// PreExecute subtest:
//   - default and configured partition counts;
//   - rejection of a partition count without a partition key;
//   - rejection of multiple partition keys;
//   - rejection of a negative partition count;
//   - rejection of a partition key with an unsupported type;
//   - rejection of a nullable partition key;
//   - rejection of a primary key that is also the partition key;
//   - rejection of too many partitions;
//   - BM25 function validation.
//
// Execute subtest: creates the collection, then checks that explicit partition
// management requests fail their PreExecute on a partition-key collection.
func TestCreateCollectionTaskWithPartitionKey(t *testing.T) {
	rc := NewRootCoordMock()
	paramtable.Init()
	defer rc.Close()
	ctx := context.Background()
	shardsNum := common.DefaultShardsNum
	prefix := "TestCreateCollectionTaskWithPartitionKey"
	dbName := ""
	collectionName := prefix + funcutil.GenRandomStr()
	int64Field := &schemapb.FieldSchema{
		Name:         "int64",
		DataType:     schemapb.DataType_Int64,
		IsPrimaryKey: true,
	}
	varCharField := &schemapb.FieldSchema{
		Name:     "varChar",
		DataType: schemapb.DataType_VarChar,
		TypeParams: []*commonpb.KeyValuePair{
			{
				Key:   "max_length",
				Value: strconv.Itoa(testMaxVarCharLength),
			},
			{
				Key:   "enable_tokenizer",
				Value: "true",
			},
		},
	}
	floatVecField := &schemapb.FieldSchema{
		Name:     "fvec",
		DataType: schemapb.DataType_FloatVector,
		TypeParams: []*commonpb.KeyValuePair{
			{
				Key:   "dim",
				Value: strconv.Itoa(testVecDim),
			},
		},
	}
	sparseVecField := &schemapb.FieldSchema{
		Name:     "sparse",
		DataType: schemapb.DataType_SparseFloatVector,
	}
	partitionKeyField := &schemapb.FieldSchema{
		Name:           "partition_key",
		DataType:       schemapb.DataType_Int64,
		IsPartitionKey: true,
	}
	schema := &schemapb.CollectionSchema{
		Name:   collectionName,
		Fields: []*schemapb.FieldSchema{int64Field, varCharField, partitionKeyField, floatVecField},
	}
	marshaledSchema, err := proto.Marshal(schema)
	assert.NoError(t, err)
	task := &createCollectionTask{
		Condition: NewTaskCondition(ctx),
		CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
			Base: &commonpb.MsgBase{
				MsgID:     UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
				Timestamp: Timestamp(time.Now().UnixNano()),
			},
			DbName:         dbName,
			CollectionName: collectionName,
			Schema:         marshaledSchema,
			ShardsNum:      shardsNum,
		},
		ctx:       ctx,
		rootCoord: rc,
		result:    nil,
		schema:    nil,
	}
	t.Run("PreExecute", func(t *testing.T) {
		defer Params.Reset(Params.RootCoordCfg.MaxPartitionNum.Key)
		var err error
		// test default num partitions
		err = task.PreExecute(ctx)
		assert.NoError(t, err)
		assert.Equal(t, common.DefaultPartitionsWithPartitionKey, task.GetNumPartitions())
		Params.Save(Params.RootCoordCfg.MaxPartitionNum.Key, "16")
		task.NumPartitions = 0
		err = task.PreExecute(ctx)
		assert.NoError(t, err)
		assert.Equal(t, int64(16), task.GetNumPartitions())
		Params.Reset(Params.RootCoordCfg.MaxPartitionNum.Key)
		// test specify num partition without partition key field
		partitionKeyField.IsPartitionKey = false
		task.NumPartitions = common.DefaultPartitionsWithPartitionKey * 2
		marshaledSchema, err = proto.Marshal(schema)
		assert.NoError(t, err)
		task.Schema = marshaledSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)
		partitionKeyField.IsPartitionKey = true
		// test multi partition key field
		varCharField.IsPartitionKey = true
		marshaledSchema, err = proto.Marshal(schema)
		assert.NoError(t, err)
		task.Schema = marshaledSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)
		varCharField.IsPartitionKey = false
		// test partitions < 0
		task.NumPartitions = -2
		marshaledSchema, err = proto.Marshal(schema)
		assert.NoError(t, err)
		task.Schema = marshaledSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)
		task.NumPartitions = 1000
		// test partition key type not in [int64, varChar]
		partitionKeyField.DataType = schemapb.DataType_FloatVector
		marshaledSchema, err = proto.Marshal(schema)
		assert.NoError(t, err)
		task.Schema = marshaledSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)
		partitionKeyField.DataType = schemapb.DataType_Int64
		// test partition key set nullable == true
		partitionKeyField.Nullable = true
		marshaledSchema, err = proto.Marshal(schema)
		assert.NoError(t, err)
		task.Schema = marshaledSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)
		partitionKeyField.Nullable = false
		// test the primary key field cannot also be a partition key field
		primaryField, _ := typeutil.GetPrimaryFieldSchema(schema)
		primaryField.IsPartitionKey = true
		marshaledSchema, err = proto.Marshal(schema)
		assert.NoError(t, err)
		task.Schema = marshaledSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)
		primaryField.IsPartitionKey = false
		// test partition num too large (1000 partitions vs. a limit of 16)
		Params.Save(Params.RootCoordCfg.MaxPartitionNum.Key, "16")
		marshaledSchema, err = proto.Marshal(schema)
		assert.NoError(t, err)
		task.Schema = marshaledSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)
		Params.Reset(Params.RootCoordCfg.MaxPartitionNum.Key)
		marshaledSchema, err = proto.Marshal(schema)
		assert.NoError(t, err)
		task.Schema = marshaledSchema
		err = task.PreExecute(ctx)
		assert.NoError(t, err)
		// test schema with function
		// invalid function
		schema.Functions = []*schemapb.FunctionSchema{
			{Name: "test", Type: schemapb.FunctionType_BM25, InputFieldNames: []string{"invalid name"}},
		}
		marshaledSchema, err = proto.Marshal(schema)
		assert.NoError(t, err)
		task.Schema = marshaledSchema
		err = task.PreExecute(ctx)
		assert.Error(t, err)
		// normal case
		schema.Fields = append(schema.Fields, sparseVecField)
		schema.Functions = []*schemapb.FunctionSchema{
			{Name: "test", Type: schemapb.FunctionType_BM25, InputFieldNames: []string{varCharField.Name}, OutputFieldNames: []string{sparseVecField.Name}},
		}
		marshaledSchema, err = proto.Marshal(schema)
		assert.NoError(t, err)
		task.Schema = marshaledSchema
		err = task.PreExecute(ctx)
		assert.NoError(t, err)
	})
	t.Run("Execute", func(t *testing.T) {
		// Use a subtest-local err instead of writing the enclosing function's
		// variable; the original code partly shadowed it via := below.
		var err error
		err = task.Execute(ctx)
		assert.NoError(t, err)
		qc := getQueryCoordClient()
		qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
		// check default partitions
		err = InitMetaCache(ctx, rc, qc, nil)
		assert.NoError(t, err)
		partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, "", task.CollectionName)
		assert.NoError(t, err)
		assert.Equal(t, task.GetNumPartitions(), int64(len(partitionNames)))
		// Explicit partition management must be rejected in partition-key mode.
		createPartitionTask := &createPartitionTask{
			Condition: NewTaskCondition(ctx),
			CreatePartitionRequest: &milvuspb.CreatePartitionRequest{
				Base: &commonpb.MsgBase{
					MsgID:     UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
					Timestamp: Timestamp(time.Now().UnixNano()),
				},
				DbName:         dbName,
				CollectionName: collectionName,
				PartitionName:  "new_partition",
			},
			ctx:       ctx,
			rootCoord: rc,
		}
		err = createPartitionTask.PreExecute(ctx)
		assert.Error(t, err)
		dropPartitionTask := &dropPartitionTask{
			Condition: NewTaskCondition(ctx),
			DropPartitionRequest: &milvuspb.DropPartitionRequest{
				Base: &commonpb.MsgBase{
					MsgID:     UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
					Timestamp: Timestamp(time.Now().UnixNano()),
				},
				DbName:         dbName,
				CollectionName: collectionName,
				PartitionName:  "new_partition",
			},
			ctx:       ctx,
			rootCoord: rc,
		}
		err = dropPartitionTask.PreExecute(ctx)
		assert.Error(t, err)
		loadPartitionTask := &loadPartitionsTask{
			Condition: NewTaskCondition(ctx),
			LoadPartitionsRequest: &milvuspb.LoadPartitionsRequest{
				Base: &commonpb.MsgBase{
					MsgID:     UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
					Timestamp: Timestamp(time.Now().UnixNano()),
				},
				DbName:         dbName,
				CollectionName: collectionName,
				PartitionNames: []string{"_default_0"},
			},
			ctx: ctx,
		}
		err = loadPartitionTask.PreExecute(ctx)
		assert.Error(t, err)
		releasePartitionsTask := &releasePartitionsTask{
			Condition: NewTaskCondition(ctx),
			ReleasePartitionsRequest: &milvuspb.ReleasePartitionsRequest{
				Base: &commonpb.MsgBase{
					MsgID:     UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
					Timestamp: Timestamp(time.Now().UnixNano()),
				},
				DbName:         dbName,
				CollectionName: collectionName,
				PartitionNames: []string{"_default_0"},
			},
			ctx: ctx,
		}
		err = releasePartitionsTask.PreExecute(ctx)
		assert.Error(t, err)
	})
}
// TestPartitionKey runs DML and DQL tasks against a collection whose schema
// contains a partition key field ("partition_key_field"). It checks that
// insert and upsert reject an explicit partition name but succeed without one,
// that a delete task runs through its whole lifecycle, and that search and
// query reject explicit partition names.
func TestPartitionKey(t *testing.T) {
	rc := NewRootCoordMock()
	defer rc.Close()
	qc := getQueryCoordClient()
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
	ctx := context.Background()
	mgr := newShardClientMgr()
	err := InitMetaCache(ctx, rc, qc, mgr)
	assert.NoError(t, err)
	shardsNum := common.DefaultShardsNum
	prefix := "TestInsertTaskWithPartitionKey"
	collectionName := prefix + funcutil.GenRandomStr()
	fieldName2Type := make(map[string]schemapb.DataType)
	fieldName2Type["int64_field"] = schemapb.DataType_Int64
	fieldName2Type["varChar_field"] = schemapb.DataType_VarChar
	fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector
	schema := constructCollectionSchemaByDataType(collectionName, fieldName2Type, "int64_field", false)
	partitionKeyField := &schemapb.FieldSchema{
		Name:           "partition_key_field",
		DataType:       schemapb.DataType_Int64,
		IsPartitionKey: true,
	}
	// The partition key field is appended to the schema by hand (so that
	// IsPartitionKey is set). It is also registered in fieldName2Type so the
	// generated insert data below includes a column for it.
	fieldName2Type["partition_key_field"] = schemapb.DataType_Int64
	schema.Fields = append(schema.Fields, partitionKeyField)
	marshaledSchema, err := proto.Marshal(schema)
	assert.NoError(t, err)
	t.Run("create collection", func(t *testing.T) {
		createCollectionTask := &createCollectionTask{
			Condition: NewTaskCondition(ctx),
			CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
				Base: &commonpb.MsgBase{
					MsgID:     UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
					Timestamp: Timestamp(time.Now().UnixNano()),
				},
				DbName:         "",
				CollectionName: collectionName,
				Schema:         marshaledSchema,
				ShardsNum:      shardsNum,
				NumPartitions:  common.DefaultPartitionsWithPartitionKey,
			},
			ctx:       ctx,
			rootCoord: rc,
			result:    nil,
			schema:    nil,
		}
		assert.NoError(t, createCollectionTask.PreExecute(ctx))
		assert.NoError(t, createCollectionTask.Execute(ctx))
	})
	collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
	assert.NoError(t, err)
	dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
	factory := newSimpleMockMsgStreamFactory()
	chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, factory)
	defer chMgr.removeAllDMLStream()
	_, err = chMgr.getOrCreateDmlStream(collectionID)
	assert.NoError(t, err)
	pchans, err := chMgr.getChannels(collectionID)
	assert.NoError(t, err)
	interval := time.Millisecond * 10
	tso := newMockTsoAllocator()
	ticker := newChannelsTimeTicker(ctx, interval, []string{}, newGetStatisticsFunc(pchans), tso)
	_ = ticker.start()
	defer ticker.close()
	idAllocator, err := allocator.NewIDAllocator(ctx, rc, paramtable.GetNodeID())
	assert.NoError(t, err)
	_ = idAllocator.Start()
	defer idAllocator.Close()
	segAllocator, err := newSegIDAssigner(ctx, &mockDataCoord{expireTime: Timestamp(2500)}, getLastTick1)
	assert.NoError(t, err)
	segAllocator.Init()
	_ = segAllocator.Start()
	defer segAllocator.Close()
	partitionNames, err := getDefaultPartitionsInPartitionKeyMode(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
	assert.NoError(t, err)
	assert.Equal(t, common.DefaultPartitionsWithPartitionKey, int64(len(partitionNames)))
	nb := 10
	fieldID := common.StartOfUserFieldID
	fieldDatas := make([]*schemapb.FieldData, 0, len(fieldName2Type))
	for fieldName, dataType := range fieldName2Type {
		fieldData := generateFieldData(dataType, fieldName, nb)
		// NOTE(review): IDs follow random map iteration order, so they are
		// placeholders rather than the schema's field IDs; presumably the task
		// re-resolves columns by field name — confirm.
		fieldData.FieldId = int64(fieldID)
		// Append the column that was just tagged (previously a second, untagged
		// column was generated here and the FieldId assignment was lost).
		fieldDatas = append(fieldDatas, fieldData)
		fieldID++
	}
	t.Run("Insert", func(t *testing.T) {
		it := &insertTask{
			insertMsg: &BaseInsertTask{
				BaseMsg: msgstream.BaseMsg{},
				InsertRequest: &msgpb.InsertRequest{
					Base: &commonpb.MsgBase{
						MsgType:  commonpb.MsgType_Insert,
						MsgID:    0,
						SourceID: paramtable.GetNodeID(),
					},
					CollectionName: collectionName,
					FieldsData:     fieldDatas,
					NumRows:        uint64(nb),
					Version:        msgpb.InsertDataVersion_ColumnBased,
				},
			},
			Condition: NewTaskCondition(ctx),
			ctx:       ctx,
			result: &milvuspb.MutationResult{
				Status:       merr.Success(),
				IDs:          nil,
				SuccIndex:    nil,
				ErrIndex:     nil,
				Acknowledged: false,
				InsertCnt:    0,
				DeleteCnt:    0,
				UpsertCnt:    0,
				Timestamp:    0,
			},
			idAllocator:   idAllocator,
			segIDAssigner: segAllocator,
			chMgr:         chMgr,
			chTicker:      ticker,
			vChannels:     nil,
			pChannels:     nil,
			schema:        nil,
		}
		// don't support specify partition name if use partition key
		it.insertMsg.PartitionName = partitionNames[0]
		assert.Error(t, it.PreExecute(ctx))
		it.insertMsg.PartitionName = ""
		assert.NoError(t, it.OnEnqueue())
		assert.NoError(t, it.PreExecute(ctx))
		assert.NoError(t, it.Execute(ctx))
		assert.NoError(t, it.PostExecute(ctx))
	})
	t.Run("Upsert", func(t *testing.T) {
		hash := testutils.GenerateHashKeys(nb)
		ut := &upsertTask{
			ctx:       ctx,
			Condition: NewTaskCondition(ctx),
			baseMsg: msgstream.BaseMsg{
				HashValues: hash,
			},
			req: &milvuspb.UpsertRequest{
				Base: commonpbutil.NewMsgBase(
					commonpbutil.WithMsgType(commonpb.MsgType_Upsert),
					commonpbutil.WithSourceID(paramtable.GetNodeID()),
				),
				CollectionName: collectionName,
				FieldsData:     fieldDatas,
				NumRows:        uint32(nb),
			},
			result: &milvuspb.MutationResult{
				Status: merr.Success(),
				IDs: &schemapb.IDs{
					IdField: nil,
				},
			},
			idAllocator:   idAllocator,
			segIDAssigner: segAllocator,
			chMgr:         chMgr,
			chTicker:      ticker,
		}
		// don't support specify partition name if use partition key
		ut.req.PartitionName = partitionNames[0]
		assert.Error(t, ut.PreExecute(ctx))
		ut.req.PartitionName = ""
		assert.NoError(t, ut.OnEnqueue())
		assert.NoError(t, ut.PreExecute(ctx))
		assert.NoError(t, ut.Execute(ctx))
		assert.NoError(t, ut.PostExecute(ctx))
	})
	t.Run("delete", func(t *testing.T) {
		dt := &deleteTask{
			Condition: NewTaskCondition(ctx),
			req: &milvuspb.DeleteRequest{
				CollectionName: collectionName,
				Expr:           "int64_field in [0, 1]",
			},
			ctx: ctx,
			primaryKeys: &schemapb.IDs{
				IdField: &schemapb.IDs_IntId{IntId: &schemapb.LongArray{Data: []int64{0, 1}}},
			},
			idAllocator:  idAllocator,
			chMgr:        chMgr,
			chTicker:     ticker,
			collectionID: collectionID,
			vChannels:    []string{"test-channel"},
		}
		dt.req.PartitionName = ""
		assert.NoError(t, dt.PreExecute(ctx))
		assert.NoError(t, dt.Execute(ctx))
		assert.NoError(t, dt.PostExecute(ctx))
	})
	t.Run("search", func(t *testing.T) {
		searchTask := &searchTask{
			ctx: ctx,
			SearchRequest: &internalpb.SearchRequest{
				Base: &commonpb.MsgBase{},
			},
			request: &milvuspb.SearchRequest{
				CollectionName: collectionName,
				Nq:             1,
			},
			qc: qc,
			tr: timerecord.NewTimeRecorder("test-search"),
		}
		// don't support specify partition name if use partition key
		searchTask.request.PartitionNames = partitionNames
		assert.Error(t, searchTask.PreExecute(ctx))
	})
	t.Run("query", func(t *testing.T) {
		queryTask := &queryTask{
			ctx: ctx,
			RetrieveRequest: &internalpb.RetrieveRequest{
				Base: &commonpb.MsgBase{},
			},
			request: &milvuspb.QueryRequest{
				CollectionName: collectionName,
			},
			qc: qc,
		}
		// don't support specify partition name if use partition key
		queryTask.request.PartitionNames = partitionNames
		assert.Error(t, queryTask.PreExecute(ctx))
	})
}
// TestClusteringKey checks clustering-key validation in
// createCollectionTask.PreExecute. A single int64 clustering key alongside a
// dimensioned float vector field is accepted (and the collection is created).
// A second clustering key, or a clustering key on a vector field, is rejected.
func TestClusteringKey(t *testing.T) {
	rc := NewRootCoordMock()
	defer rc.Close()
	qc := getQueryCoordClient()
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
	ctx := context.Background()
	mgr := newShardClientMgr()
	err := InitMetaCache(ctx, rc, qc, mgr)
	assert.NoError(t, err)
	shardsNum := common.DefaultShardsNum
	prefix := "TestClusteringKey"
	collectionName := prefix + funcutil.GenRandomStr()

	// newBaseSchema builds a schema with "int64_field" (passed as the primary
	// key name) and "varChar_field".
	newBaseSchema := func() *schemapb.CollectionSchema {
		fieldName2Type := map[string]schemapb.DataType{
			"int64_field":   schemapb.DataType_Int64,
			"varChar_field": schemapb.DataType_VarChar,
		}
		return constructCollectionSchemaByDataType(collectionName, fieldName2Type, "int64_field", false)
	}
	// newInt64ClusteringKeyField returns an int64 field flagged as clustering key.
	newInt64ClusteringKeyField := func(name string) *schemapb.FieldSchema {
		return &schemapb.FieldSchema{
			Name:            name,
			DataType:        schemapb.DataType_Int64,
			IsClusteringKey: true,
		}
	}
	// newFloatVectorField returns a float vector field with dim = testVecDim.
	newFloatVectorField := func(name string, isClusteringKey bool) *schemapb.FieldSchema {
		return &schemapb.FieldSchema{
			Name:            name,
			DataType:        schemapb.DataType_FloatVector,
			IsClusteringKey: isClusteringKey,
			TypeParams: []*commonpb.KeyValuePair{
				{
					Key:   common.DimKey,
					Value: strconv.Itoa(testVecDim),
				},
			},
		}
	}
	// newCreateCollectionTask marshals schema and wraps it in a create task
	// for collectionName.
	newCreateCollectionTask := func(t *testing.T, schema *schemapb.CollectionSchema) *createCollectionTask {
		marshaledSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		return &createCollectionTask{
			Condition: NewTaskCondition(ctx),
			CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
				Base: &commonpb.MsgBase{
					MsgID:     UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
					Timestamp: Timestamp(time.Now().UnixNano()),
				},
				DbName:         "",
				CollectionName: collectionName,
				Schema:         marshaledSchema,
				ShardsNum:      shardsNum,
			},
			ctx:       ctx,
			rootCoord: rc,
			result:    nil,
			schema:    nil,
		}
	}

	t.Run("create collection normal", func(t *testing.T) {
		schema := newBaseSchema()
		schema.Fields = append(schema.Fields,
			newInt64ClusteringKeyField("cluster_key_field"),
			newFloatVectorField("fvec_field", false),
		)
		createTask := newCreateCollectionTask(t, schema)
		assert.NoError(t, createTask.PreExecute(ctx))
		assert.NoError(t, createTask.Execute(ctx))
	})
	t.Run("create collection not support more than one clustering key", func(t *testing.T) {
		schema := newBaseSchema()
		// Same shape as the accepted case plus a second clustering key. The
		// expected error should therefore come from the duplicate clustering key,
		// not from a missing vector field.
		schema.Fields = append(schema.Fields,
			newInt64ClusteringKeyField("cluster_key_field"),
			newInt64ClusteringKeyField("cluster_key_field2"),
			newFloatVectorField("fvec_field", false),
		)
		assert.Error(t, newCreateCollectionTask(t, schema).PreExecute(ctx))
	})
	t.Run("create collection with vector clustering key", func(t *testing.T) {
		schema := newBaseSchema()
		// The vector field carries a dimension. The expected error should
		// therefore come from using a vector field as the clustering key, not
		// from a missing dim.
		schema.Fields = append(schema.Fields, newFloatVectorField("vec_field", true))
		assert.Error(t, newCreateCollectionTask(t, schema).PreExecute(ctx))
	})
}
// TestAlterCollectionCheckLoaded verifies that an alterCollectionTask setting
// common.MmapEnabledKey is rejected with ErrCollectionLoaded when query coord
// reports the collection as 100% loaded.
func TestAlterCollectionCheckLoaded(t *testing.T) {
	rc := NewRootCoordMock()
	defer rc.Close()
	rc.state.Store(commonpb.StateCode_Healthy)
	qc := &mocks.MockQueryCoordClient{}
	err := InitMetaCache(context.Background(), rc, qc, nil)
	assert.NoError(t, err)

	collectionName := "test_alter_collection_check_loaded"
	createColReq := &milvuspb.CreateCollectionRequest{
		Base: &commonpb.MsgBase{
			MsgType:   commonpb.MsgType_CreateCollection,
			MsgID:     100,
			Timestamp: 100,
		},
		DbName:         dbName,
		CollectionName: collectionName,
		Schema:         nil,
		ShardsNum:      1,
	}
	createStatus, err := rc.CreateCollection(context.Background(), createColReq)
	assert.NoError(t, err)
	assert.Equal(t, commonpb.ErrorCode_Success, createStatus.GetErrorCode())

	resp, err := rc.DescribeCollection(context.Background(), &milvuspb.DescribeCollectionRequest{CollectionName: collectionName})
	assert.NoError(t, err)
	// Report the freshly created collection as fully loaded.
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{
		Status:              &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
		CollectionIDs:       []int64{resp.CollectionID},
		InMemoryPercentages: []int64{100},
	}, nil)

	task := &alterCollectionTask{
		AlterCollectionRequest: &milvuspb.AlterCollectionRequest{
			Base:           &commonpb.MsgBase{},
			CollectionName: collectionName,
			Properties:     []*commonpb.KeyValuePair{{Key: common.MmapEnabledKey, Value: "true"}},
		},
		queryCoord: qc,
	}
	err = task.PreExecute(context.Background())
	assert.Equal(t, merr.Code(merr.ErrCollectionLoaded), merr.Code(err))
}
// TestTaskPartitionKeyIsolation covers create/alter collection validation of
// the common.PartitionKeyIsolationKey property. The cases exercise a missing
// partition key field, materialized views disabled via
// CommonCfg.EnableMaterializedView, and a collection that already has a vector
// index.
func TestTaskPartitionKeyIsolation(t *testing.T) {
	rc := NewRootCoordMock()
	defer rc.Close()
	dc := NewDataCoordMock()
	defer dc.Close()
	qc := getQueryCoordClient()
	qc.EXPECT().ShowCollections(mock.Anything, mock.Anything).Return(&querypb.ShowCollectionsResponse{}, nil).Maybe()
	defer qc.Close()
	ctx := context.Background()
	mgr := newShardClientMgr()
	err := InitMetaCache(ctx, rc, qc, mgr)
	assert.NoError(t, err)
	shardsNum := common.DefaultShardsNum
	prefix := "TestPartitionKeyIsolation"
	collectionName := prefix + funcutil.GenRandomStr()

	// getSchema builds a schema with "fvec_field", "varChar_field" and
	// "int64_field" (passed as the primary key name). When hasPartitionKey is
	// true, it also adds an int64 "partition_key_field" flagged as partition key.
	getSchema := func(colName string, hasPartitionKey bool) *schemapb.CollectionSchema {
		fieldName2Type := make(map[string]schemapb.DataType)
		fieldName2Type["fvec_field"] = schemapb.DataType_FloatVector
		fieldName2Type["varChar_field"] = schemapb.DataType_VarChar
		fieldName2Type["int64_field"] = schemapb.DataType_Int64
		schema := constructCollectionSchemaByDataType(colName, fieldName2Type, "int64_field", false)
		if hasPartitionKey {
			partitionKeyField := &schemapb.FieldSchema{
				Name:           "partition_key_field",
				DataType:       schemapb.DataType_Int64,
				IsPartitionKey: true,
			}
			schema.Fields = append(schema.Fields, partitionKeyField)
		}
		return schema
	}
	// getCollectionTask wraps a marshaled schema in a proxy create task with
	// the isolation property set to isIso.
	getCollectionTask := func(colName string, isIso bool, marshaledSchema []byte) *createCollectionTask {
		return &createCollectionTask{
			Condition: NewTaskCondition(ctx),
			CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
				Base: &commonpb.MsgBase{
					MsgID:     UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
					Timestamp: Timestamp(time.Now().UnixNano()),
				},
				DbName:         "",
				CollectionName: colName,
				Schema:         marshaledSchema,
				ShardsNum:      shardsNum,
				Properties:     []*commonpb.KeyValuePair{{Key: common.PartitionKeyIsolationKey, Value: strconv.FormatBool(isIso)}},
			},
			ctx:       ctx,
			rootCoord: rc,
			result:    nil,
			schema:    nil,
		}
	}
	// createIsoCollection creates a collection directly through the root coord
	// mock, bypassing proxy-side validation. When isIsoNil is true, no
	// properties are sent at all.
	createIsoCollection := func(colName string, hasPartitionKey bool, isIsolation bool, isIsoNil bool) {
		schema := getSchema(colName, hasPartitionKey)
		marshaledSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		createColReq := &milvuspb.CreateCollectionRequest{
			Base: &commonpb.MsgBase{
				MsgType:   commonpb.MsgType_CreateCollection,
				MsgID:     100,
				Timestamp: 100,
			},
			DbName:         dbName,
			CollectionName: colName,
			Schema:         marshaledSchema,
			ShardsNum:      1,
			Properties:     []*commonpb.KeyValuePair{{Key: common.PartitionKeyIsolationKey, Value: strconv.FormatBool(isIsolation)}},
		}
		if isIsoNil {
			createColReq.Properties = nil
		}
		stats, err := rc.CreateCollection(ctx, createColReq)
		assert.NoError(t, err)
		assert.Equal(t, commonpb.ErrorCode_Success, stats.ErrorCode)
	}
	// getAlterCollectionTask builds an alter task that sets the isolation
	// property to isIsolation.
	getAlterCollectionTask := func(colName string, isIsolation bool) *alterCollectionTask {
		return &alterCollectionTask{
			AlterCollectionRequest: &milvuspb.AlterCollectionRequest{
				Base:           &commonpb.MsgBase{},
				CollectionName: colName,
				Properties:     []*commonpb.KeyValuePair{{Key: common.PartitionKeyIsolationKey, Value: strconv.FormatBool(isIsolation)}},
			},
			queryCoord: qc,
			dataCoord:  dc,
		}
	}

	t.Run("create collection valid", func(t *testing.T) {
		paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("true")
		defer paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("false")
		schema := getSchema(collectionName, true)
		marshaledSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		createCollectionTask := getCollectionTask(collectionName, true, marshaledSchema)
		err = createCollectionTask.PreExecute(ctx)
		assert.NoError(t, err)
		err = createCollectionTask.Execute(ctx)
		assert.NoError(t, err)
	})
	t.Run("create collection without isolation", func(t *testing.T) {
		schema := getSchema(collectionName, true)
		marshaledSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		createCollectionTask := getCollectionTask(collectionName, false, marshaledSchema)
		err = createCollectionTask.PreExecute(ctx)
		assert.NoError(t, err)
		err = createCollectionTask.Execute(ctx)
		assert.NoError(t, err)
	})
	t.Run("create collection isolation but no partition key", func(t *testing.T) {
		schema := getSchema(collectionName, false)
		marshaledSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		createCollectionTask := getCollectionTask(collectionName, true, marshaledSchema)
		assert.ErrorContains(t, createCollectionTask.PreExecute(ctx), "partition key isolation mode is enabled but no partition key field is set")
	})
	t.Run("create collection with isolation and partition key but MV is not enabled", func(t *testing.T) {
		paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("false")
		schema := getSchema(collectionName, true)
		marshaledSchema, err := proto.Marshal(schema)
		assert.NoError(t, err)
		createCollectionTask := getCollectionTask(collectionName, true, marshaledSchema)
		assert.ErrorContains(t, createCollectionTask.PreExecute(ctx), "partition key isolation mode is enabled but current Milvus does not support it")
	})
	t.Run("alter collection from valid", func(t *testing.T) {
		paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("true")
		defer paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("false")
		colName := collectionName + "AlterValid"
		createIsoCollection(colName, true, false, false)
		alterTask := getAlterCollectionTask(colName, true)
		err := alterTask.PreExecute(ctx)
		assert.NoError(t, err)
	})
	t.Run("alter collection without isolation", func(t *testing.T) {
		colName := collectionName + "AlterNoIso"
		createIsoCollection(colName, true, false, true)
		alterTask := alterCollectionTask{
			AlterCollectionRequest: &milvuspb.AlterCollectionRequest{
				Base:           &commonpb.MsgBase{},
				CollectionName: colName,
				Properties:     nil,
			},
			queryCoord: qc,
		}
		err := alterTask.PreExecute(ctx)
		assert.NoError(t, err)
	})
	t.Run("alter collection isolation but no partition key", func(t *testing.T) {
		colName := collectionName + "AlterNoPartkey"
		createIsoCollection(colName, false, false, false)
		alterTask := getAlterCollectionTask(colName, true)
		assert.ErrorContains(t, alterTask.PreExecute(ctx), "partition key isolation mode is enabled but no partition key field is set")
	})
	t.Run("alter collection with isolation and partition key but MV is not enabled", func(t *testing.T) {
		paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("false")
		colName := collectionName + "AlterNoMv"
		createIsoCollection(colName, true, false, false)
		alterTask := getAlterCollectionTask(colName, true)
		assert.ErrorContains(t, alterTask.PreExecute(ctx), "partition key isolation mode is enabled but current Milvus does not support it")
	})
	t.Run("alter collection with vec index and isolation", func(t *testing.T) {
		paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("true")
		defer paramtable.Get().CommonCfg.EnableMaterializedView.SwapTempValue("false")
		colName := collectionName + "AlterVecIndex"
		createIsoCollection(colName, true, true, false)
		resp, err := rc.DescribeCollection(ctx, &milvuspb.DescribeCollectionRequest{DbName: dbName, CollectionName: colName})
		assert.NoError(t, err)
		// Locate the float vector field so the mocked index can reference it.
		var vecFieldID int64
		for _, field := range resp.Schema.Fields {
			if field.DataType == schemapb.DataType_FloatVector {
				vecFieldID = field.FieldID
				break
			}
		}
		assert.NotEqual(t, int64(0), vecFieldID)
		dc.DescribeIndexFunc = func(ctx context.Context, request *indexpb.DescribeIndexRequest, opts ...grpc.CallOption) (*indexpb.DescribeIndexResponse, error) {
			return &indexpb.DescribeIndexResponse{
				Status: merr.Success(),
				IndexInfos: []*indexpb.IndexInfo{
					{
						FieldID: vecFieldID,
					},
				},
			}, nil
		}
		alterTask := getAlterCollectionTask(colName, false)
		assert.ErrorContains(t, alterTask.PreExecute(ctx),
			"can not alter partition key isolation mode if the collection already has a vector index. Please drop the index first")
	})
}