mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
Cherry-pick from master pr: #45018 #45030 Related to #44761 Refactor proxy shard client management by creating a new internal/proxy/shardclient package. This improves code organization and modularity by: - Moving load balancing logic (LookAsideBalancer, RoundRobinBalancer) to shardclient package - Extracting shard client manager and related interfaces into separate package - Relocating shard leader management and client lifecycle code - Adding package documentation (README.md, OWNERS) - Updating proxy code to use the new shardclient package interfaces This change makes the shard client functionality more maintainable and better encapsulated, reducing coupling in the proxy layer. Also consolidates the proxy package's mockery generation to use a centralized `.mockery.yaml` configuration file, aligning with the pattern used by other packages like querycoordv2. Changes - **Makefile**: Replace multiple individual mockery commands with a single config-based invocation for `generate-mockery-proxy` target - **internal/proxy/.mockery.yaml**: Add mockery configuration defining all mock interfaces for proxy and proxy/shardclient packages - **Mock files**: Regenerate mocks using the new configuration: - `mock_cache.go`: Clean up by removing unused interface methods (credential, shard cache, policy methods) - `shardclient/mock_lb_balancer.go`: Update type comments (nodeInfo → NodeInfo) - `shardclient/mock_lb_policy.go`: Update formatting - `shardclient/mock_shardclient_manager.go`: Fix parameter naming consistency (nodeInfo1 → nodeInfo) - **task_search_test.go**: Remove obsolete mock expectations for deprecated cache methods Benefits - Centralized mockery configuration for easier maintenance - Consistent with other packages (querycoordv2, etc.) - Cleaner mock interfaces by removing unused methods - Better type consistency in generated mocks --------- Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
1304 lines
39 KiB
Go
1304 lines
39 KiB
Go
// Licensed to the LF AI & Data foundation under one
|
|
// or more contributor license agreements. See the NOTICE file
|
|
// distributed with this work for additional information
|
|
// regarding copyright ownership. The ASF licenses this file
|
|
// to you under the Apache License, Version 2.0 (the
|
|
// "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/cockroachdb/errors"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/mock"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/zap"
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/internal/proxy/shardclient"
|
|
"github.com/milvus-io/milvus/internal/util/reduce"
|
|
"github.com/milvus-io/milvus/pkg/v2/common"
|
|
"github.com/milvus-io/milvus/pkg/v2/log"
|
|
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
|
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/testutils"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
|
)
|
|
|
|
func TestQueryTask_all(t *testing.T) {
|
|
var (
|
|
err error
|
|
ctx = context.TODO()
|
|
|
|
qc = NewMixCoordMock()
|
|
qn = getQueryNodeClient()
|
|
|
|
shardsNum = common.DefaultShardsNum
|
|
collectionName = t.Name() + funcutil.GenRandomStr()
|
|
|
|
expr = fmt.Sprintf("%s > 0", testInt64Field)
|
|
hitNum = 10
|
|
)
|
|
|
|
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
|
|
|
mgr := shardclient.NewMockShardClientManager(t)
|
|
mgr.EXPECT().GetClient(mock.Anything, mock.Anything).Return(qn, nil).Maybe()
|
|
mgr.EXPECT().GetShardLeaderList(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]string{"mock_qn"}, nil).Maybe()
|
|
mgr.EXPECT().GetShard(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]shardclient.NodeInfo{
|
|
{NodeID: 1, Address: "mock_qn", Serviceable: true},
|
|
}, nil).Maybe()
|
|
mgr.EXPECT().DeprecateShardCache(mock.Anything, mock.Anything).Return().Maybe()
|
|
lb := shardclient.NewLBPolicyImpl(mgr)
|
|
|
|
err = InitMetaCache(ctx, qc)
|
|
assert.NoError(t, err)
|
|
|
|
fieldName2Types := map[string]schemapb.DataType{
|
|
testBoolField: schemapb.DataType_Bool,
|
|
testInt32Field: schemapb.DataType_Int32,
|
|
testInt64Field: schemapb.DataType_Int64,
|
|
testFloatField: schemapb.DataType_Float,
|
|
testDoubleField: schemapb.DataType_Double,
|
|
testFloatVecField: schemapb.DataType_FloatVector,
|
|
}
|
|
if enableMultipleVectorFields {
|
|
fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
|
|
}
|
|
|
|
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false)
|
|
marshaledSchema, err := proto.Marshal(schema)
|
|
assert.NoError(t, err)
|
|
|
|
createColT := &createCollectionTask{
|
|
Condition: NewTaskCondition(ctx),
|
|
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
|
CollectionName: collectionName,
|
|
Schema: marshaledSchema,
|
|
ShardsNum: shardsNum,
|
|
},
|
|
ctx: ctx,
|
|
mixCoord: qc,
|
|
}
|
|
|
|
require.NoError(t, createColT.OnEnqueue())
|
|
require.NoError(t, createColT.PreExecute(ctx))
|
|
require.NoError(t, createColT.Execute(ctx))
|
|
require.NoError(t, createColT.PostExecute(ctx))
|
|
collectionID, err := globalMetaCache.GetCollectionID(ctx, GetCurDBNameFromContextOrDefault(ctx), collectionName)
|
|
assert.NoError(t, err)
|
|
|
|
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_LoadCollection,
|
|
SourceID: paramtable.GetNodeID(),
|
|
},
|
|
CollectionID: collectionID,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
|
|
|
t.Run("test query task parameters", func(t *testing.T) {
|
|
task := &queryTask{
|
|
Condition: NewTaskCondition(ctx),
|
|
RetrieveRequest: &internalpb.RetrieveRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_Retrieve,
|
|
SourceID: paramtable.GetNodeID(),
|
|
},
|
|
CollectionID: collectionID,
|
|
OutputFieldsId: make([]int64, len(fieldName2Types)),
|
|
},
|
|
ctx: ctx,
|
|
result: &milvuspb.QueryResults{
|
|
Status: merr.Success(),
|
|
FieldsData: []*schemapb.FieldData{},
|
|
},
|
|
request: &milvuspb.QueryRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_Retrieve,
|
|
SourceID: paramtable.GetNodeID(),
|
|
},
|
|
CollectionName: collectionName,
|
|
Expr: expr,
|
|
QueryParams: []*commonpb.KeyValuePair{
|
|
{
|
|
Key: IgnoreGrowingKey,
|
|
Value: "false",
|
|
},
|
|
},
|
|
},
|
|
mixCoord: qc,
|
|
lb: lb,
|
|
shardclientMgr: mgr,
|
|
}
|
|
|
|
assert.NoError(t, task.OnEnqueue())
|
|
|
|
// test query task with timeout
|
|
ctx1, cancel1 := context.WithTimeout(ctx, 10*time.Second)
|
|
defer cancel1()
|
|
// before preExecute
|
|
assert.Equal(t, typeutil.ZeroTimestamp, task.TimeoutTimestamp)
|
|
task.ctx = ctx1
|
|
assert.NoError(t, task.PreExecute(ctx))
|
|
|
|
{
|
|
task.mustUsePartitionKey = true
|
|
err := task.PreExecute(ctx)
|
|
assert.Error(t, err)
|
|
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
|
task.mustUsePartitionKey = false
|
|
}
|
|
|
|
// after preExecute
|
|
assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp)
|
|
|
|
// check reduce_stop_for_best
|
|
assert.Equal(t, false, task.RetrieveRequest.GetReduceStopForBest())
|
|
task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{
|
|
Key: ReduceStopForBestKey,
|
|
Value: "trxxxx",
|
|
})
|
|
assert.Error(t, task.PreExecute(ctx))
|
|
task.request.QueryParams = task.request.QueryParams[0 : len(task.request.QueryParams)-1]
|
|
|
|
// check parse collection id
|
|
task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{
|
|
Key: CollectionID,
|
|
Value: "trxxxx",
|
|
})
|
|
err := task.PreExecute(ctx)
|
|
assert.Error(t, err)
|
|
task.request.QueryParams = task.request.QueryParams[0 : len(task.request.QueryParams)-1]
|
|
|
|
// check collection id consistency
|
|
task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{
|
|
Key: LimitKey,
|
|
Value: "11",
|
|
})
|
|
task.request.QueryParams = append(task.request.QueryParams, &commonpb.KeyValuePair{
|
|
Key: CollectionID,
|
|
Value: "8080",
|
|
})
|
|
err = task.PreExecute(ctx)
|
|
assert.Error(t, err)
|
|
task.request.QueryParams = make([]*commonpb.KeyValuePair, 0)
|
|
|
|
result1 := &internalpb.RetrieveResults{
|
|
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_RetrieveResult},
|
|
Status: merr.Success(),
|
|
Ids: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{Data: testutils.GenerateInt64Array(hitNum)},
|
|
},
|
|
},
|
|
}
|
|
|
|
outputFieldIDs := make([]UniqueID, 0, len(fieldName2Types))
|
|
for i := 0; i < len(fieldName2Types); i++ {
|
|
outputFieldIDs = append(outputFieldIDs, int64(common.StartOfUserFieldID+i))
|
|
}
|
|
task.RetrieveRequest.OutputFieldsId = outputFieldIDs
|
|
for fieldName, dataType := range fieldName2Types {
|
|
result1.FieldsData = append(result1.FieldsData, generateFieldData(dataType, fieldName, hitNum))
|
|
}
|
|
result1.FieldsData = append(result1.FieldsData, generateFieldData(schemapb.DataType_Int64, common.TimeStampFieldName, hitNum))
|
|
task.RetrieveRequest.OutputFieldsId = append(task.RetrieveRequest.OutputFieldsId, common.TimeStampField)
|
|
task.ctx = ctx
|
|
qn.ExpectedCalls = nil
|
|
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
|
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(nil, errors.New("mock error"))
|
|
assert.Error(t, task.Execute(ctx))
|
|
|
|
qn.ExpectedCalls = nil
|
|
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
|
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{
|
|
Status: merr.Status(merr.ErrChannelNotAvailable),
|
|
}, nil)
|
|
err = task.Execute(ctx)
|
|
assert.ErrorIs(t, err, merr.ErrChannelNotAvailable)
|
|
|
|
qn.ExpectedCalls = nil
|
|
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
|
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(&internalpb.RetrieveResults{
|
|
Status: &commonpb.Status{
|
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
|
},
|
|
}, nil)
|
|
assert.Error(t, task.Execute(ctx))
|
|
|
|
qn.ExpectedCalls = nil
|
|
qn.EXPECT().GetComponentStates(mock.Anything, mock.Anything).Return(nil, nil).Maybe()
|
|
qn.EXPECT().Query(mock.Anything, mock.Anything).Return(result1, nil)
|
|
assert.NoError(t, task.Execute(ctx))
|
|
|
|
task.queryParams = &queryParams{
|
|
limit: 100,
|
|
offset: 100,
|
|
}
|
|
assert.NoError(t, task.PostExecute(ctx))
|
|
|
|
for i := 0; i < len(task.result.FieldsData); i++ {
|
|
assert.NotEqual(t, task.result.FieldsData[i].FieldId, common.TimeStampField)
|
|
}
|
|
})
|
|
|
|
t.Run("test query for iterator", func(t *testing.T) {
|
|
qt := &queryTask{
|
|
Condition: NewTaskCondition(ctx),
|
|
RetrieveRequest: &internalpb.RetrieveRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_Retrieve,
|
|
SourceID: paramtable.GetNodeID(),
|
|
},
|
|
CollectionID: collectionID,
|
|
OutputFieldsId: make([]int64, len(fieldName2Types)),
|
|
},
|
|
ctx: ctx,
|
|
result: &milvuspb.QueryResults{
|
|
Status: merr.Success(),
|
|
FieldsData: []*schemapb.FieldData{},
|
|
},
|
|
request: &milvuspb.QueryRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_Retrieve,
|
|
SourceID: paramtable.GetNodeID(),
|
|
},
|
|
CollectionName: collectionName,
|
|
Expr: expr,
|
|
QueryParams: []*commonpb.KeyValuePair{
|
|
{
|
|
Key: IgnoreGrowingKey,
|
|
Value: "false",
|
|
},
|
|
{
|
|
Key: IteratorField,
|
|
Value: "True",
|
|
},
|
|
},
|
|
},
|
|
mixCoord: qc,
|
|
lb: lb,
|
|
shardclientMgr: mgr,
|
|
resultBuf: &typeutil.ConcurrentSet[*internalpb.RetrieveResults]{},
|
|
}
|
|
// simulate scheduler enqueue task
|
|
enqueTs := uint64(10000)
|
|
qt.SetTs(enqueTs)
|
|
qtErr := qt.PreExecute(context.TODO())
|
|
assert.Nil(t, qtErr)
|
|
assert.True(t, qt.queryParams.isIterator)
|
|
qt.resultBuf.Insert(&internalpb.RetrieveResults{})
|
|
qtErr = qt.PostExecute(context.TODO())
|
|
assert.Nil(t, qtErr)
|
|
// after first page, sessionTs is set
|
|
assert.True(t, qt.result.GetSessionTs() > 0)
|
|
|
|
// next page query task
|
|
qt = &queryTask{
|
|
Condition: NewTaskCondition(ctx),
|
|
RetrieveRequest: &internalpb.RetrieveRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_Retrieve,
|
|
SourceID: paramtable.GetNodeID(),
|
|
},
|
|
CollectionID: collectionID,
|
|
OutputFieldsId: make([]int64, len(fieldName2Types)),
|
|
},
|
|
ctx: ctx,
|
|
result: &milvuspb.QueryResults{
|
|
Status: merr.Success(),
|
|
FieldsData: []*schemapb.FieldData{},
|
|
},
|
|
request: &milvuspb.QueryRequest{
|
|
Base: &commonpb.MsgBase{
|
|
MsgType: commonpb.MsgType_Retrieve,
|
|
SourceID: paramtable.GetNodeID(),
|
|
},
|
|
CollectionName: collectionName,
|
|
Expr: expr,
|
|
QueryParams: []*commonpb.KeyValuePair{
|
|
{
|
|
Key: IgnoreGrowingKey,
|
|
Value: "false",
|
|
},
|
|
{
|
|
Key: IteratorField,
|
|
Value: "True",
|
|
},
|
|
},
|
|
GuaranteeTimestamp: enqueTs,
|
|
},
|
|
mixCoord: qc,
|
|
lb: lb,
|
|
shardclientMgr: mgr,
|
|
resultBuf: &typeutil.ConcurrentSet[*internalpb.RetrieveResults]{},
|
|
}
|
|
qtErr = qt.PreExecute(context.TODO())
|
|
assert.Nil(t, qtErr)
|
|
assert.True(t, qt.queryParams.isIterator)
|
|
// from the second page, the mvccTs is set to the sessionTs init in the first page
|
|
assert.Equal(t, enqueTs, qt.GetMvccTimestamp())
|
|
})
|
|
}
|
|
|
|
func Test_translateToOutputFieldIDs(t *testing.T) {
|
|
type testCases struct {
|
|
name string
|
|
outputFields []string
|
|
schema *schemapb.CollectionSchema
|
|
expectedError bool
|
|
expectedIDs []int64
|
|
}
|
|
|
|
cases := []testCases{
|
|
{
|
|
name: "empty output fields",
|
|
outputFields: []string{},
|
|
schema: &schemapb.CollectionSchema{
|
|
Fields: []*schemapb.FieldSchema{
|
|
{
|
|
FieldID: common.RowIDField,
|
|
Name: common.RowIDFieldName,
|
|
},
|
|
{
|
|
FieldID: 100,
|
|
Name: "ID",
|
|
IsPrimaryKey: true,
|
|
},
|
|
{
|
|
FieldID: 101,
|
|
Name: "Vector",
|
|
},
|
|
},
|
|
},
|
|
expectedError: false,
|
|
expectedIDs: []int64{100},
|
|
},
|
|
{
|
|
name: "nil output fields",
|
|
outputFields: nil,
|
|
schema: &schemapb.CollectionSchema{
|
|
Fields: []*schemapb.FieldSchema{
|
|
{
|
|
FieldID: common.RowIDField,
|
|
Name: common.RowIDFieldName,
|
|
},
|
|
{
|
|
FieldID: 100,
|
|
Name: "ID",
|
|
IsPrimaryKey: true,
|
|
},
|
|
{
|
|
FieldID: 101,
|
|
Name: "Vector",
|
|
},
|
|
},
|
|
},
|
|
expectedError: false,
|
|
expectedIDs: []int64{100},
|
|
},
|
|
{
|
|
name: "full list",
|
|
outputFields: []string{"ID", "Vector"},
|
|
schema: &schemapb.CollectionSchema{
|
|
Fields: []*schemapb.FieldSchema{
|
|
{
|
|
FieldID: common.RowIDField,
|
|
Name: common.RowIDFieldName,
|
|
},
|
|
{
|
|
FieldID: 100,
|
|
Name: "ID",
|
|
IsPrimaryKey: true,
|
|
},
|
|
{
|
|
FieldID: 101,
|
|
Name: "Vector",
|
|
},
|
|
},
|
|
},
|
|
expectedError: false,
|
|
expectedIDs: []int64{100, 101},
|
|
},
|
|
{
|
|
name: "vector only",
|
|
outputFields: []string{"Vector"},
|
|
schema: &schemapb.CollectionSchema{
|
|
Fields: []*schemapb.FieldSchema{
|
|
{
|
|
FieldID: common.RowIDField,
|
|
Name: common.RowIDFieldName,
|
|
},
|
|
{
|
|
FieldID: 100,
|
|
Name: "ID",
|
|
IsPrimaryKey: true,
|
|
},
|
|
{
|
|
FieldID: 101,
|
|
Name: "Vector",
|
|
},
|
|
},
|
|
},
|
|
expectedError: false,
|
|
expectedIDs: []int64{101, 100},
|
|
},
|
|
{
|
|
name: "with field not exist",
|
|
outputFields: []string{"ID", "Vector", "Extra"},
|
|
schema: &schemapb.CollectionSchema{
|
|
Fields: []*schemapb.FieldSchema{
|
|
{
|
|
FieldID: common.RowIDField,
|
|
Name: common.RowIDFieldName,
|
|
},
|
|
{
|
|
FieldID: 100,
|
|
Name: "ID",
|
|
IsPrimaryKey: true,
|
|
},
|
|
{
|
|
FieldID: 101,
|
|
Name: "Vector",
|
|
},
|
|
},
|
|
},
|
|
expectedError: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range cases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
ids, err := translateToOutputFieldIDs(tc.outputFields, tc.schema)
|
|
if tc.expectedError {
|
|
assert.Error(t, err)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
require.Equal(t, len(tc.expectedIDs), len(ids))
|
|
for idx, expectedID := range tc.expectedIDs {
|
|
assert.Equal(t, expectedID, ids[idx])
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestTaskQuery_functions(t *testing.T) {
|
|
t.Run("test parseQueryParams", func(t *testing.T) {
|
|
tests := []struct {
|
|
description string
|
|
|
|
inKey []string
|
|
inValue []string
|
|
|
|
expectErr bool
|
|
outLimit int64
|
|
outOffset int64
|
|
}{
|
|
{"empty input", []string{}, []string{}, false, typeutil.Unlimited, 0},
|
|
{"valid limit=1", []string{LimitKey}, []string{"1"}, false, 1, 0},
|
|
{"valid limit=1, offset=2", []string{LimitKey, OffsetKey}, []string{"1", "2"}, false, 1, 2},
|
|
{"valid no limit, offset=2", []string{OffsetKey}, []string{"2"}, false, typeutil.Unlimited, 0},
|
|
{"invalid limit str", []string{LimitKey}, []string{"a"}, true, 0, 0},
|
|
{"invalid limit zero", []string{LimitKey}, []string{"0"}, true, 0, 0},
|
|
{"invalid limit negative", []string{LimitKey}, []string{"-1"}, true, 0, 0},
|
|
{"invalid limit 16385", []string{LimitKey}, []string{"16385"}, true, 0, 0},
|
|
{"invalid offset negative", []string{LimitKey, OffsetKey}, []string{"1", "-1"}, true, 0, 0},
|
|
{"invalid offset 16385", []string{LimitKey, OffsetKey}, []string{"1", "16385"}, true, 0, 0},
|
|
{"invalid limit=16384 offset=16384", []string{LimitKey, OffsetKey}, []string{"16384", "16384"}, true, 0, 0},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run(test.description, func(t *testing.T) {
|
|
var inParams []*commonpb.KeyValuePair
|
|
for i := range test.inKey {
|
|
inParams = append(inParams, &commonpb.KeyValuePair{
|
|
Key: test.inKey[i],
|
|
Value: test.inValue[i],
|
|
})
|
|
}
|
|
ret, err := parseQueryParams(inParams)
|
|
if test.expectErr {
|
|
assert.Error(t, err)
|
|
assert.Empty(t, ret)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, test.outLimit, ret.limit)
|
|
assert.Equal(t, test.outOffset, ret.offset)
|
|
}
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("test parseQueryParams for reduce type", func(t *testing.T) {
|
|
{
|
|
var inParams []*commonpb.KeyValuePair
|
|
inParams = append(inParams, &commonpb.KeyValuePair{
|
|
Key: ReduceStopForBestKey,
|
|
Value: "True",
|
|
})
|
|
inParams = append(inParams, &commonpb.KeyValuePair{
|
|
Key: IteratorField,
|
|
Value: "True",
|
|
})
|
|
ret, err := parseQueryParams(inParams)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, reduce.IReduceInOrderForBest, ret.reduceType)
|
|
}
|
|
{
|
|
var inParams []*commonpb.KeyValuePair
|
|
inParams = append(inParams, &commonpb.KeyValuePair{
|
|
Key: ReduceStopForBestKey,
|
|
Value: "True",
|
|
})
|
|
inParams = append(inParams, &commonpb.KeyValuePair{
|
|
Key: IteratorField,
|
|
Value: "TrueXXXX",
|
|
})
|
|
ret, err := parseQueryParams(inParams)
|
|
assert.Error(t, err)
|
|
assert.Nil(t, ret)
|
|
}
|
|
{
|
|
var inParams []*commonpb.KeyValuePair
|
|
inParams = append(inParams, &commonpb.KeyValuePair{
|
|
Key: ReduceStopForBestKey,
|
|
Value: "TrueXXXXX",
|
|
})
|
|
inParams = append(inParams, &commonpb.KeyValuePair{
|
|
Key: IteratorField,
|
|
Value: "True",
|
|
})
|
|
ret, err := parseQueryParams(inParams)
|
|
assert.Error(t, err)
|
|
assert.Nil(t, ret)
|
|
}
|
|
{
|
|
var inParams []*commonpb.KeyValuePair
|
|
inParams = append(inParams, &commonpb.KeyValuePair{
|
|
Key: ReduceStopForBestKey,
|
|
Value: "True",
|
|
})
|
|
// when not setting iterator tag, ignore reduce_stop_for_best
|
|
ret, err := parseQueryParams(inParams)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, reduce.IReduceNoOrder, ret.reduceType)
|
|
}
|
|
{
|
|
var inParams []*commonpb.KeyValuePair
|
|
inParams = append(inParams, &commonpb.KeyValuePair{
|
|
Key: IteratorField,
|
|
Value: "True",
|
|
})
|
|
// when not setting reduce_stop_for_best tag, reduce by keep results in order
|
|
ret, err := parseQueryParams(inParams)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, reduce.IReduceInOrder, ret.reduceType)
|
|
}
|
|
{
|
|
var inParams []*commonpb.KeyValuePair
|
|
inParams = append(inParams, &commonpb.KeyValuePair{
|
|
Key: ReduceStopForBestKey,
|
|
Value: "False",
|
|
})
|
|
inParams = append(inParams, &commonpb.KeyValuePair{
|
|
Key: IteratorField,
|
|
Value: "True",
|
|
})
|
|
ret, err := parseQueryParams(inParams)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, reduce.IReduceInOrder, ret.reduceType)
|
|
}
|
|
{
|
|
var inParams []*commonpb.KeyValuePair
|
|
inParams = append(inParams, &commonpb.KeyValuePair{
|
|
Key: ReduceStopForBestKey,
|
|
Value: "False",
|
|
})
|
|
inParams = append(inParams, &commonpb.KeyValuePair{
|
|
Key: IteratorField,
|
|
Value: "False",
|
|
})
|
|
ret, err := parseQueryParams(inParams)
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, reduce.IReduceNoOrder, ret.reduceType)
|
|
}
|
|
})
|
|
|
|
t.Run("test reduceRetrieveResults", func(t *testing.T) {
|
|
const (
|
|
Dim = 8
|
|
Int64FieldName = "Int64Field"
|
|
FloatVectorFieldName = "FloatVectorField"
|
|
Int64FieldID = common.StartOfUserFieldID + 1
|
|
FloatVectorFieldID = common.StartOfUserFieldID + 2
|
|
)
|
|
Int64Array := []int64{11, 22}
|
|
FloatVector := []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0}
|
|
|
|
var fieldDataArray1 []*schemapb.FieldData
|
|
fieldDataArray1 = append(fieldDataArray1, getFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1))
|
|
fieldDataArray1 = append(fieldDataArray1, getFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim))
|
|
|
|
var fieldDataArray2 []*schemapb.FieldData
|
|
fieldDataArray2 = append(fieldDataArray2, getFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, Int64Array[0:2], 1))
|
|
fieldDataArray2 = append(fieldDataArray2, getFieldData(FloatVectorFieldName, FloatVectorFieldID, schemapb.DataType_FloatVector, FloatVector[0:16], Dim))
|
|
|
|
t.Run("test nil results", func(t *testing.T) {
|
|
ret, err := reduceRetrieveResults(context.Background(), nil, &queryParams{})
|
|
assert.NoError(t, err)
|
|
assert.Empty(t, ret.GetFieldsData())
|
|
})
|
|
|
|
t.Run("test merge", func(t *testing.T) {
|
|
r1 := &internalpb.RetrieveResults{
|
|
Ids: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{1, 3},
|
|
},
|
|
},
|
|
},
|
|
FieldsData: fieldDataArray1,
|
|
}
|
|
r2 := &internalpb.RetrieveResults{
|
|
Ids: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{2, 4},
|
|
},
|
|
},
|
|
},
|
|
FieldsData: fieldDataArray2,
|
|
}
|
|
|
|
resultFloat := []float32{
|
|
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
|
|
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
|
|
11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0,
|
|
11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0,
|
|
}
|
|
|
|
t.Run("test limited", func(t *testing.T) {
|
|
tests := []struct {
|
|
description string
|
|
limit int64
|
|
}{
|
|
{"limit 1", 1},
|
|
{"limit 2", 2},
|
|
{"limit 3", 3},
|
|
{"limit 4", 4},
|
|
}
|
|
resultField0 := []int64{11, 11, 22, 22}
|
|
for _, test := range tests {
|
|
t.Run(test.description, func(t *testing.T) {
|
|
result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, &queryParams{limit: test.limit})
|
|
assert.Equal(t, 2, len(result.GetFieldsData()))
|
|
assert.Equal(t, resultField0[0:test.limit], result.GetFieldsData()[0].GetScalars().GetLongData().Data)
|
|
assert.InDeltaSlice(t, resultFloat[0:test.limit*Dim], result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
|
|
assert.NoError(t, err)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("test unLimited and maxOutputSize", func(t *testing.T) {
|
|
paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1")
|
|
|
|
ids := make([]int64, 100)
|
|
offsets := make([]int64, 100)
|
|
for i := range ids {
|
|
ids[i] = int64(i)
|
|
offsets[i] = int64(i)
|
|
}
|
|
fieldData := getFieldData(Int64FieldName, Int64FieldID, schemapb.DataType_Int64, ids, 1)
|
|
|
|
result := &internalpb.RetrieveResults{
|
|
Ids: &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: ids,
|
|
},
|
|
},
|
|
},
|
|
FieldsData: []*schemapb.FieldData{fieldData},
|
|
}
|
|
|
|
_, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{result}, &queryParams{limit: typeutil.Unlimited})
|
|
assert.Error(t, err)
|
|
paramtable.Get().Save(paramtable.Get().QuotaConfig.MaxOutputSize.Key, "1104857600")
|
|
})
|
|
|
|
t.Run("test offset", func(t *testing.T) {
|
|
tests := []struct {
|
|
description string
|
|
offset int64
|
|
}{
|
|
{"offset 0", 0},
|
|
{"offset 1", 1},
|
|
{"offset 2", 2},
|
|
{"offset 3", 3},
|
|
}
|
|
resultField0 := []int64{11, 11, 22, 22}
|
|
for _, test := range tests {
|
|
t.Run(test.description, func(t *testing.T) {
|
|
result, err := reduceRetrieveResults(context.Background(), []*internalpb.RetrieveResults{r1, r2}, &queryParams{limit: 1, offset: test.offset})
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 2, len(result.GetFieldsData()))
|
|
assert.Equal(t, resultField0[test.offset:test.offset+1], result.GetFieldsData()[0].GetScalars().GetLongData().Data)
|
|
assert.InDeltaSlice(t, resultFloat[test.offset*Dim:(test.offset+1)*Dim], result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
|
|
})
|
|
}
|
|
})
|
|
|
|
t.Run("test stop reduce for best for limit", func(t *testing.T) {
|
|
r1.HasMoreResult = true
|
|
r2.HasMoreResult = false
|
|
result, err := reduceRetrieveResults(context.Background(),
|
|
[]*internalpb.RetrieveResults{r1, r2},
|
|
&queryParams{limit: 2, reduceType: reduce.IReduceInOrderForBest})
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 2, len(result.GetFieldsData()))
|
|
assert.Equal(t, []int64{11, 11, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
|
|
len := len(result.GetFieldsData()[0].GetScalars().GetLongData().Data)
|
|
assert.InDeltaSlice(t, resultFloat[0:(len)*Dim], result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
|
|
})
|
|
|
|
t.Run("test stop reduce for best for limit and offset", func(t *testing.T) {
|
|
r1.HasMoreResult = true
|
|
r2.HasMoreResult = true
|
|
result, err := reduceRetrieveResults(context.Background(),
|
|
[]*internalpb.RetrieveResults{r1, r2},
|
|
&queryParams{limit: 1, offset: 1, reduceType: reduce.IReduceInOrderForBest})
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 2, len(result.GetFieldsData()))
|
|
assert.Equal(t, []int64{11, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
|
|
})
|
|
|
|
t.Run("test stop reduce for best for limit and offset", func(t *testing.T) {
|
|
r1.HasMoreResult = false
|
|
r2.HasMoreResult = true
|
|
result, err := reduceRetrieveResults(context.Background(),
|
|
[]*internalpb.RetrieveResults{r1, r2},
|
|
&queryParams{limit: 2, offset: 1, reduceType: reduce.IReduceInOrderForBest})
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 2, len(result.GetFieldsData()))
|
|
|
|
// we should get 6 result back in total, but only get 4, which means all the result should actually be part of result
|
|
assert.Equal(t, []int64{11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
|
|
})
|
|
|
|
t.Run("test stop reduce for best for unlimited set", func(t *testing.T) {
|
|
r1.HasMoreResult = false
|
|
r2.HasMoreResult = false
|
|
result, err := reduceRetrieveResults(context.Background(),
|
|
[]*internalpb.RetrieveResults{r1, r2},
|
|
&queryParams{limit: typeutil.Unlimited, reduceType: reduce.IReduceInOrderForBest})
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 2, len(result.GetFieldsData()))
|
|
assert.Equal(t, []int64{11, 11, 22, 22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
|
|
len := len(result.GetFieldsData()[0].GetScalars().GetLongData().Data)
|
|
assert.InDeltaSlice(t, resultFloat[0:(len)*Dim], result.FieldsData[1].GetVectors().GetFloatVector().Data, 10e-10)
|
|
})
|
|
|
|
t.Run("test stop reduce for best for unlimited set amd offset", func(t *testing.T) {
|
|
result, err := reduceRetrieveResults(context.Background(),
|
|
[]*internalpb.RetrieveResults{r1, r2},
|
|
&queryParams{limit: typeutil.Unlimited, offset: 3, reduceType: reduce.IReduceInOrderForBest})
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 2, len(result.GetFieldsData()))
|
|
assert.Equal(t, []int64{22}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
|
|
})
|
|
t.Run("test iterator without setting reduce stop for best", func(t *testing.T) {
|
|
r1.HasMoreResult = true
|
|
r2.HasMoreResult = true
|
|
result, err := reduceRetrieveResults(context.Background(),
|
|
[]*internalpb.RetrieveResults{r1, r2},
|
|
&queryParams{limit: 1, reduceType: reduce.IReduceInOrder})
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, 2, len(result.GetFieldsData()))
|
|
assert.Equal(t, []int64{11}, result.GetFieldsData()[0].GetScalars().GetLongData().Data)
|
|
})
|
|
})
|
|
})
|
|
}
|
|
|
|
func getFieldData(fieldName string, fieldID int64, fieldType schemapb.DataType, fieldValue interface{}, dim int64) *schemapb.FieldData {
|
|
var fieldData *schemapb.FieldData
|
|
switch fieldType {
|
|
case schemapb.DataType_Bool:
|
|
fieldData = &schemapb.FieldData{
|
|
Type: schemapb.DataType_Bool,
|
|
FieldName: fieldName,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_BoolData{
|
|
BoolData: &schemapb.BoolArray{
|
|
Data: fieldValue.([]bool),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
FieldId: fieldID,
|
|
}
|
|
case schemapb.DataType_Int32:
|
|
fieldData = &schemapb.FieldData{
|
|
Type: schemapb.DataType_Int32,
|
|
FieldName: fieldName,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_IntData{
|
|
IntData: &schemapb.IntArray{
|
|
Data: fieldValue.([]int32),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
FieldId: fieldID,
|
|
}
|
|
case schemapb.DataType_Int64:
|
|
fieldData = &schemapb.FieldData{
|
|
Type: schemapb.DataType_Int64,
|
|
FieldName: fieldName,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_LongData{
|
|
LongData: &schemapb.LongArray{
|
|
Data: fieldValue.([]int64),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
FieldId: fieldID,
|
|
}
|
|
case schemapb.DataType_Float:
|
|
fieldData = &schemapb.FieldData{
|
|
Type: schemapb.DataType_Float,
|
|
FieldName: fieldName,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_FloatData{
|
|
FloatData: &schemapb.FloatArray{
|
|
Data: fieldValue.([]float32),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
FieldId: fieldID,
|
|
}
|
|
case schemapb.DataType_Double:
|
|
fieldData = &schemapb.FieldData{
|
|
Type: schemapb.DataType_Double,
|
|
FieldName: fieldName,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_DoubleData{
|
|
DoubleData: &schemapb.DoubleArray{
|
|
Data: fieldValue.([]float64),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
FieldId: fieldID,
|
|
}
|
|
case schemapb.DataType_VarChar:
|
|
fieldData = &schemapb.FieldData{
|
|
Type: schemapb.DataType_VarChar,
|
|
FieldName: fieldName,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_StringData{
|
|
StringData: &schemapb.StringArray{
|
|
Data: fieldValue.([]string),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
FieldId: fieldID,
|
|
}
|
|
|
|
case schemapb.DataType_BinaryVector:
|
|
fieldData = &schemapb.FieldData{
|
|
Type: schemapb.DataType_BinaryVector,
|
|
FieldName: fieldName,
|
|
Field: &schemapb.FieldData_Vectors{
|
|
Vectors: &schemapb.VectorField{
|
|
Dim: dim,
|
|
Data: &schemapb.VectorField_BinaryVector{
|
|
BinaryVector: fieldValue.([]byte),
|
|
},
|
|
},
|
|
},
|
|
FieldId: fieldID,
|
|
}
|
|
case schemapb.DataType_FloatVector:
|
|
fieldData = &schemapb.FieldData{
|
|
Type: schemapb.DataType_FloatVector,
|
|
FieldName: fieldName,
|
|
Field: &schemapb.FieldData_Vectors{
|
|
Vectors: &schemapb.VectorField{
|
|
Dim: dim,
|
|
Data: &schemapb.VectorField_FloatVector{
|
|
FloatVector: &schemapb.FloatArray{
|
|
Data: fieldValue.([]float32),
|
|
},
|
|
},
|
|
},
|
|
},
|
|
FieldId: fieldID,
|
|
}
|
|
default:
|
|
log.Warn("not supported field type", zap.String("fieldType", fieldType.String()))
|
|
}
|
|
|
|
return fieldData
|
|
}
|
|
|
|
func Test_filterSystemFields(t *testing.T) {
|
|
outputFieldIDs := []UniqueID{common.RowIDField, common.TimeStampField, common.StartOfUserFieldID}
|
|
filtered := filterSystemFields(outputFieldIDs)
|
|
assert.ElementsMatch(t, []UniqueID{common.StartOfUserFieldID}, filtered)
|
|
}
|
|
|
|
func Test_matchCountRule(t *testing.T) {
|
|
type args struct {
|
|
outputs []string
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
want bool
|
|
}{
|
|
{
|
|
args: args{
|
|
outputs: nil,
|
|
},
|
|
want: false,
|
|
},
|
|
{
|
|
args: args{
|
|
outputs: []string{"count(*)", "count(*)"},
|
|
},
|
|
want: false,
|
|
},
|
|
{
|
|
args: args{
|
|
outputs: []string{"not count(*)"},
|
|
},
|
|
want: false,
|
|
},
|
|
{
|
|
args: args{
|
|
outputs: []string{"count(*)"},
|
|
},
|
|
want: true,
|
|
},
|
|
{
|
|
args: args{
|
|
outputs: []string{"COUNT(*)"},
|
|
},
|
|
want: true,
|
|
},
|
|
{
|
|
args: args{
|
|
outputs: []string{" count(*)"},
|
|
},
|
|
want: true,
|
|
},
|
|
{
|
|
args: args{
|
|
outputs: []string{" COUNT(*)"},
|
|
},
|
|
want: true,
|
|
},
|
|
{
|
|
args: args{
|
|
outputs: []string{"count(*) "},
|
|
},
|
|
want: true,
|
|
},
|
|
{
|
|
args: args{
|
|
outputs: []string{"COUNT(*) "},
|
|
},
|
|
want: true,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert.Equalf(t, tt.want, matchCountRule(tt.args.outputs), "matchCountRule(%v)", tt.args.outputs)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_createCntPlan(t *testing.T) {
|
|
t.Run("plan without filter", func(t *testing.T) {
|
|
plan, err := createCntPlan("", nil, nil)
|
|
assert.NoError(t, err)
|
|
assert.True(t, plan.GetQuery().GetIsCount())
|
|
assert.Nil(t, plan.GetQuery().GetPredicates())
|
|
})
|
|
|
|
t.Run("invalid schema", func(t *testing.T) {
|
|
schema := &schemapb.CollectionSchema{
|
|
Fields: []*schemapb.FieldSchema{
|
|
{
|
|
FieldID: 100,
|
|
Name: "a",
|
|
IsPrimaryKey: true,
|
|
DataType: schemapb.DataType_Int64,
|
|
},
|
|
},
|
|
}
|
|
schemaHelper, err := typeutil.CreateSchemaHelper(schema)
|
|
require.NoError(t, err)
|
|
plan, err := createCntPlan("a > 4", schemaHelper, nil)
|
|
assert.NoError(t, err)
|
|
assert.True(t, plan.GetQuery().GetIsCount())
|
|
assert.NotNil(t, plan.GetQuery().GetPredicates())
|
|
})
|
|
}
|
|
|
|
func Test_queryTask_createPlan(t *testing.T) {
|
|
collSchema := newTestSchema()
|
|
t.Run("match count rule", func(t *testing.T) {
|
|
schema := newSchemaInfo(collSchema)
|
|
tsk := &queryTask{
|
|
request: &milvuspb.QueryRequest{
|
|
OutputFields: []string{"count(*)"},
|
|
},
|
|
schema: schema,
|
|
}
|
|
err := tsk.createPlan(context.TODO())
|
|
assert.NoError(t, err)
|
|
plan := tsk.plan
|
|
assert.True(t, plan.GetQuery().GetIsCount())
|
|
assert.Nil(t, plan.GetQuery().GetPredicates())
|
|
})
|
|
|
|
t.Run("query without expression", func(t *testing.T) {
|
|
schema := newSchemaInfo(collSchema)
|
|
tsk := &queryTask{
|
|
request: &milvuspb.QueryRequest{
|
|
OutputFields: []string{"Int64"},
|
|
},
|
|
schema: schema,
|
|
}
|
|
err := tsk.createPlan(context.TODO())
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("invalid expression", func(t *testing.T) {
|
|
schema := newSchemaInfo(collSchema)
|
|
|
|
tsk := &queryTask{
|
|
schema: schema,
|
|
request: &milvuspb.QueryRequest{
|
|
OutputFields: []string{"a"},
|
|
Expr: "b > 2",
|
|
},
|
|
}
|
|
err := tsk.createPlan(context.TODO())
|
|
assert.Error(t, err)
|
|
})
|
|
|
|
t.Run("invalid output fields", func(t *testing.T) {
|
|
schema := newSchemaInfo(collSchema)
|
|
|
|
tsk := &queryTask{
|
|
schema: schema,
|
|
request: &milvuspb.QueryRequest{
|
|
OutputFields: []string{"b"},
|
|
Expr: "a > 2",
|
|
},
|
|
}
|
|
err := tsk.createPlan(context.TODO())
|
|
assert.Error(t, err)
|
|
})
|
|
}
|
|
|
|
func TestQueryTask_IDs2Expr(t *testing.T) {
|
|
fieldName := "pk"
|
|
intIDs := &schemapb.IDs{
|
|
IdField: &schemapb.IDs_IntId{
|
|
IntId: &schemapb.LongArray{
|
|
Data: []int64{1, 2, 3, 4, 5},
|
|
},
|
|
},
|
|
}
|
|
stringIDs := &schemapb.IDs{
|
|
IdField: &schemapb.IDs_StrId{
|
|
StrId: &schemapb.StringArray{
|
|
Data: []string{"a", "b", "c"},
|
|
},
|
|
},
|
|
}
|
|
idExpr := IDs2Expr(fieldName, intIDs)
|
|
expectIDExpr := "pk in [ 1, 2, 3, 4, 5 ]"
|
|
assert.Equal(t, expectIDExpr, idExpr)
|
|
|
|
strExpr := IDs2Expr(fieldName, stringIDs)
|
|
expectStrExpr := "pk in [ \"a\", \"b\", \"c\" ]"
|
|
assert.Equal(t, expectStrExpr, strExpr)
|
|
}
|
|
|
|
func TestQueryTask_CanSkipAllocTimestamp(t *testing.T) {
|
|
dbName := "test_query"
|
|
collName := "test_skip_alloc_timestamp"
|
|
collID := UniqueID(111)
|
|
mockMetaCache := NewMockCache(t)
|
|
globalMetaCache = mockMetaCache
|
|
|
|
t.Run("default consistency level", func(t *testing.T) {
|
|
qt := &queryTask{
|
|
request: &milvuspb.QueryRequest{
|
|
Base: nil,
|
|
DbName: dbName,
|
|
CollectionName: collName,
|
|
UseDefaultConsistency: true,
|
|
},
|
|
}
|
|
mockMetaCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(collID, nil)
|
|
mockMetaCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(
|
|
&collectionInfo{
|
|
collID: collID,
|
|
consistencyLevel: commonpb.ConsistencyLevel_Eventually,
|
|
}, nil).Once()
|
|
|
|
skip := qt.CanSkipAllocTimestamp()
|
|
assert.True(t, skip)
|
|
|
|
mockMetaCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(
|
|
&collectionInfo{
|
|
collID: collID,
|
|
consistencyLevel: commonpb.ConsistencyLevel_Bounded,
|
|
}, nil).Once()
|
|
skip = qt.CanSkipAllocTimestamp()
|
|
assert.True(t, skip)
|
|
|
|
mockMetaCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(
|
|
&collectionInfo{
|
|
collID: collID,
|
|
consistencyLevel: commonpb.ConsistencyLevel_Strong,
|
|
}, nil).Once()
|
|
skip = qt.CanSkipAllocTimestamp()
|
|
assert.False(t, skip)
|
|
})
|
|
|
|
t.Run("request consistency level", func(t *testing.T) {
|
|
mockMetaCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(
|
|
&collectionInfo{
|
|
collID: collID,
|
|
consistencyLevel: commonpb.ConsistencyLevel_Eventually,
|
|
}, nil).Times(3)
|
|
|
|
qt := &queryTask{
|
|
request: &milvuspb.QueryRequest{
|
|
Base: nil,
|
|
DbName: dbName,
|
|
CollectionName: collName,
|
|
UseDefaultConsistency: false,
|
|
ConsistencyLevel: commonpb.ConsistencyLevel_Eventually,
|
|
},
|
|
}
|
|
|
|
skip := qt.CanSkipAllocTimestamp()
|
|
assert.True(t, skip)
|
|
|
|
qt.request.ConsistencyLevel = commonpb.ConsistencyLevel_Bounded
|
|
skip = qt.CanSkipAllocTimestamp()
|
|
assert.True(t, skip)
|
|
|
|
qt.request.ConsistencyLevel = commonpb.ConsistencyLevel_Strong
|
|
skip = qt.CanSkipAllocTimestamp()
|
|
assert.False(t, skip)
|
|
})
|
|
|
|
t.Run("legacy_guarantee_ts", func(t *testing.T) {
|
|
qt := &queryTask{
|
|
request: &milvuspb.QueryRequest{
|
|
Base: nil,
|
|
DbName: dbName,
|
|
CollectionName: collName,
|
|
UseDefaultConsistency: false,
|
|
ConsistencyLevel: commonpb.ConsistencyLevel_Strong,
|
|
},
|
|
}
|
|
|
|
skip := qt.CanSkipAllocTimestamp()
|
|
assert.False(t, skip)
|
|
|
|
qt.request.GuaranteeTimestamp = 1 // eventually
|
|
skip = qt.CanSkipAllocTimestamp()
|
|
assert.True(t, skip)
|
|
|
|
qt.request.GuaranteeTimestamp = 2 // bounded
|
|
skip = qt.CanSkipAllocTimestamp()
|
|
assert.True(t, skip)
|
|
})
|
|
|
|
t.Run("failed", func(t *testing.T) {
|
|
mockMetaCache.ExpectedCalls = nil
|
|
mockMetaCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(collID, nil)
|
|
mockMetaCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(
|
|
nil, errors.New("mock error")).Once()
|
|
|
|
qt := &queryTask{
|
|
request: &milvuspb.QueryRequest{
|
|
Base: nil,
|
|
DbName: dbName,
|
|
CollectionName: collName,
|
|
UseDefaultConsistency: true,
|
|
ConsistencyLevel: commonpb.ConsistencyLevel_Eventually,
|
|
},
|
|
}
|
|
|
|
skip := qt.CanSkipAllocTimestamp()
|
|
assert.False(t, skip)
|
|
|
|
mockMetaCache.ExpectedCalls = nil
|
|
mockMetaCache.EXPECT().GetCollectionID(mock.Anything, mock.Anything, mock.Anything).Return(collID, errors.New("mock error"))
|
|
mockMetaCache.EXPECT().GetCollectionInfo(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(
|
|
&collectionInfo{
|
|
collID: collID,
|
|
consistencyLevel: commonpb.ConsistencyLevel_Eventually,
|
|
}, nil)
|
|
|
|
skip = qt.CanSkipAllocTimestamp()
|
|
assert.False(t, skip)
|
|
|
|
qt2 := &queryTask{
|
|
request: &milvuspb.QueryRequest{
|
|
Base: nil,
|
|
DbName: dbName,
|
|
CollectionName: collName,
|
|
UseDefaultConsistency: false,
|
|
ConsistencyLevel: commonpb.ConsistencyLevel_Eventually,
|
|
},
|
|
}
|
|
|
|
skip = qt2.CanSkipAllocTimestamp()
|
|
assert.True(t, skip)
|
|
})
|
|
}
|