mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-08 01:58:34 +08:00
Coverage error case of searchTask.PreExecute (#7826)
Signed-off-by: dragondriver <jiquan.long@zilliz.com>
This commit is contained in:
parent
664d4ca923
commit
bcf9b4e240
@ -41,11 +41,6 @@ type getChannelsService interface {
|
||||
GetChannels(collectionID UniqueID) (map[vChan]pChan, error)
|
||||
}
|
||||
|
||||
// queryCoordShowCollectionsInterface used in searchTask & queryTask
|
||||
type queryCoordShowCollectionsInterface interface {
|
||||
ShowCollections(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error)
|
||||
}
|
||||
|
||||
// createQueryChannelInterface defines CreateQueryChannel
|
||||
type createQueryChannelInterface interface {
|
||||
CreateQueryChannel(ctx context.Context, request *querypb.CreateQueryChannelRequest) (*querypb.CreateQueryChannelResponse, error)
|
||||
|
||||
@ -17,8 +17,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/msgstream"
|
||||
@ -555,33 +553,3 @@ func generateHashKeys(numRows int) []uint32 {
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
type mockQueryCoordShowCollectionsInterface struct {
|
||||
collectionIDs []int64
|
||||
inMemoryPercentages []int64
|
||||
}
|
||||
|
||||
func (ins *mockQueryCoordShowCollectionsInterface) ShowCollections(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||
resp := &querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
},
|
||||
CollectionIDs: ins.collectionIDs,
|
||||
InMemoryPercentages: ins.inMemoryPercentages,
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (ins *mockQueryCoordShowCollectionsInterface) addCollection(collectionID int64, inMemoryPercentage int64) {
|
||||
ins.collectionIDs = append(ins.collectionIDs, collectionID)
|
||||
ins.inMemoryPercentages = append(ins.inMemoryPercentages, collectionID)
|
||||
}
|
||||
|
||||
func newMockQueryCoordShowCollectionsInterface() *mockQueryCoordShowCollectionsInterface {
|
||||
return &mockQueryCoordShowCollectionsInterface{
|
||||
collectionIDs: make([]int64, 0),
|
||||
inMemoryPercentages: make([]int64, 0),
|
||||
}
|
||||
}
|
||||
|
||||
312
internal/proxy/query_coord_mock_test.go
Normal file
312
internal/proxy/query_coord_mock_test.go
Normal file
@ -0,0 +1,312 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||
"github.com/milvus-io/milvus/internal/util/uniquegenerator"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||
)
|
||||
|
||||
type QueryCoordMockOption func(mock *QueryCoordMock)
|
||||
|
||||
type queryCoordShowCollectionsFuncType func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error)
|
||||
|
||||
func SetQueryCoordShowCollectionsFunc(f queryCoordShowCollectionsFuncType) QueryCoordMockOption {
|
||||
return func(mock *QueryCoordMock) {
|
||||
mock.showCollectionsFunc = f
|
||||
}
|
||||
}
|
||||
|
||||
type QueryCoordMock struct {
|
||||
nodeID typeutil.UniqueID
|
||||
address string
|
||||
|
||||
state atomic.Value // internal.StateCode
|
||||
|
||||
collectionIDs []int64
|
||||
inMemoryPercentages []int64
|
||||
colMtx sync.RWMutex
|
||||
|
||||
showCollectionsFunc queryCoordShowCollectionsFuncType
|
||||
|
||||
statisticsChannel string
|
||||
timeTickChannel string
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) updateState(state internalpb.StateCode) {
|
||||
coord.state.Store(state)
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) getState() internalpb.StateCode {
|
||||
return coord.state.Load().(internalpb.StateCode)
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) healthy() bool {
|
||||
return coord.getState() == internalpb.StateCode_Healthy
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) Init() error {
|
||||
coord.updateState(internalpb.StateCode_Initializing)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) Start() error {
|
||||
defer coord.updateState(internalpb.StateCode_Healthy)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) Stop() error {
|
||||
defer coord.updateState(internalpb.StateCode_Abnormal)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
|
||||
return &internalpb.ComponentStates{
|
||||
State: &internalpb.ComponentInfo{
|
||||
NodeID: coord.nodeID,
|
||||
Role: typeutil.QueryCoordRole,
|
||||
StateCode: coord.getState(),
|
||||
ExtraInfo: nil,
|
||||
},
|
||||
SubcomponentStates: nil,
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
|
||||
return &milvuspb.StringResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
},
|
||||
Value: coord.statisticsChannel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) Register() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
|
||||
return &milvuspb.StringResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
},
|
||||
Value: coord.timeTickChannel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) ResetShowCollectionsFunc() {
|
||||
coord.showCollectionsFunc = nil
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) SetShowCollectionsFunc(f queryCoordShowCollectionsFuncType) {
|
||||
coord.showCollectionsFunc = f
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||
if !coord.healthy() {
|
||||
return &querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "unhealthy",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
if coord.showCollectionsFunc != nil {
|
||||
return coord.showCollectionsFunc(ctx, req)
|
||||
}
|
||||
|
||||
coord.colMtx.RLock()
|
||||
defer coord.colMtx.RUnlock()
|
||||
|
||||
resp := &querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
},
|
||||
CollectionIDs: coord.collectionIDs,
|
||||
InMemoryPercentages: coord.inMemoryPercentages,
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) {
|
||||
if !coord.healthy() {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "unhealthy",
|
||||
}, nil
|
||||
}
|
||||
|
||||
coord.colMtx.Lock()
|
||||
defer coord.colMtx.Unlock()
|
||||
|
||||
for _, colID := range coord.collectionIDs {
|
||||
if req.CollectionID == colID {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: fmt.Sprintf("collection %v already loaded", req.CollectionID),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
coord.collectionIDs = append(coord.collectionIDs, req.CollectionID)
|
||||
coord.inMemoryPercentages = append(coord.inMemoryPercentages, 100)
|
||||
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
|
||||
if !coord.healthy() {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "unhealthy",
|
||||
}, nil
|
||||
}
|
||||
|
||||
coord.colMtx.Lock()
|
||||
defer coord.colMtx.Unlock()
|
||||
|
||||
for i := len(coord.collectionIDs) - 1; i >= 0; i-- {
|
||||
if req.CollectionID == coord.collectionIDs[i] {
|
||||
coord.collectionIDs = append(coord.collectionIDs[:i], coord.collectionIDs[i+1:]...)
|
||||
coord.inMemoryPercentages = append(coord.inMemoryPercentages[:i], coord.inMemoryPercentages[i+1:]...)
|
||||
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_Success,
|
||||
Reason: "",
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: fmt.Sprintf("collection %v not loaded", req.CollectionID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
|
||||
if !coord.healthy() {
|
||||
return &querypb.ShowPartitionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "unhealthy",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
|
||||
if !coord.healthy() {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "unhealthy",
|
||||
}, nil
|
||||
}
|
||||
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
|
||||
if !coord.healthy() {
|
||||
return &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "unhealthy",
|
||||
}, nil
|
||||
}
|
||||
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) CreateQueryChannel(ctx context.Context, req *querypb.CreateQueryChannelRequest) (*querypb.CreateQueryChannelResponse, error) {
|
||||
if !coord.healthy() {
|
||||
return &querypb.CreateQueryChannelResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "unhealthy",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) {
|
||||
if !coord.healthy() {
|
||||
return &querypb.GetPartitionStatesResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "unhealthy",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
|
||||
if !coord.healthy() {
|
||||
return &querypb.GetSegmentInfoResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "unhealthy",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (coord *QueryCoordMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
|
||||
if !coord.healthy() {
|
||||
return &milvuspb.GetMetricsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "unhealthy",
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func NewQueryCoordMock(opts ...QueryCoordMockOption) *QueryCoordMock {
|
||||
coord := &QueryCoordMock{
|
||||
nodeID: UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
|
||||
address: funcutil.GenRandomStr(), // TODO(dragondriver): random address
|
||||
state: atomic.Value{},
|
||||
collectionIDs: make([]int64, 0),
|
||||
inMemoryPercentages: make([]int64, 0),
|
||||
colMtx: sync.RWMutex{},
|
||||
statisticsChannel: funcutil.GenRandomStr(),
|
||||
timeTickChannel: funcutil.GenRandomStr(),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(coord)
|
||||
}
|
||||
|
||||
return coord
|
||||
}
|
||||
@ -73,11 +73,7 @@ type RootCoordMockOption func(mock *RootCoordMock)
|
||||
|
||||
type describeCollectionFuncType func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error)
|
||||
|
||||
func SetDescribeCollectionFunc(f describeCollectionFuncType) RootCoordMockOption {
|
||||
return func(mock *RootCoordMock) {
|
||||
mock.SetDescribeCollectionFunc(f)
|
||||
}
|
||||
}
|
||||
type showPartitionsFuncType func(ctx context.Context, request *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error)
|
||||
|
||||
type RootCoordMock struct {
|
||||
nodeID typeutil.UniqueID
|
||||
@ -98,6 +94,7 @@ type RootCoordMock struct {
|
||||
partitionMtx sync.RWMutex
|
||||
|
||||
describeCollectionFunc describeCollectionFuncType
|
||||
showPartitionsFunc showPartitionsFuncType
|
||||
|
||||
// TODO(dragondriver): index-related
|
||||
|
||||
@ -332,7 +329,7 @@ func (coord *RootCoordMock) SetDescribeCollectionFunc(f describeCollectionFuncTy
|
||||
coord.describeCollectionFunc = f
|
||||
}
|
||||
|
||||
func (coord *RootCoordMock) ResetDescribeCollectionFunc(f describeCollectionFuncType) {
|
||||
func (coord *RootCoordMock) ResetDescribeCollectionFunc() {
|
||||
coord.describeCollectionFunc = nil
|
||||
}
|
||||
|
||||
@ -560,6 +557,11 @@ func (coord *RootCoordMock) ShowPartitions(ctx context.Context, req *milvuspb.Sh
|
||||
PartitionIDs: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if coord.showPartitionsFunc != nil {
|
||||
return coord.showPartitionsFunc(ctx, req)
|
||||
}
|
||||
|
||||
coord.collMtx.RLock()
|
||||
defer coord.collMtx.RUnlock()
|
||||
|
||||
|
||||
@ -1350,7 +1350,7 @@ type searchTask struct {
|
||||
result *milvuspb.SearchResults
|
||||
query *milvuspb.SearchRequest
|
||||
chMgr channelsMgr
|
||||
qc queryCoordShowCollectionsInterface
|
||||
qc types.QueryCoord
|
||||
}
|
||||
|
||||
func (st *searchTask) TraceCtx() context.Context {
|
||||
@ -1398,6 +1398,14 @@ func (st *searchTask) getChannels() ([]pChan, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = st.chMgr.getChannels(collID)
|
||||
if err != nil {
|
||||
err := st.chMgr.createDMLMsgStream(collID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return st.chMgr.getChannels(collID)
|
||||
}
|
||||
|
||||
@ -1474,10 +1482,7 @@ func (st *searchTask) PreExecute(ctx context.Context) error {
|
||||
|
||||
st.Base.MsgType = commonpb.MsgType_Search
|
||||
|
||||
schema, err := globalMetaCache.GetCollectionSchema(ctx, collectionName)
|
||||
if err != nil { // err is not nil if collection not exists
|
||||
return err
|
||||
}
|
||||
schema, _ := globalMetaCache.GetCollectionSchema(ctx, collectionName)
|
||||
|
||||
outputFields, err := translateOutputFields(st.query.OutputFields, schema, false)
|
||||
if err != nil {
|
||||
@ -1575,11 +1580,7 @@ func (st *searchTask) PreExecute(ctx context.Context) error {
|
||||
|
||||
st.SearchRequest.ResultChannelID = Params.SearchResultChannelNames[0]
|
||||
st.SearchRequest.DbID = 0 // todo
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
||||
if err != nil { // err is not nil if collection not exists
|
||||
return err
|
||||
}
|
||||
st.SearchRequest.CollectionID = collectionID
|
||||
st.SearchRequest.CollectionID = collID
|
||||
st.SearchRequest.PartitionIDs = make([]UniqueID, 0)
|
||||
|
||||
partitionsMap, err := globalMetaCache.GetPartitions(ctx, collectionName)
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
@ -12,6 +13,8 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/msgstream"
|
||||
|
||||
"github.com/milvus-io/milvus/internal/util/distance"
|
||||
@ -1467,8 +1470,22 @@ func TestSearchTask_all(t *testing.T) {
|
||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
||||
assert.NoError(t, err)
|
||||
|
||||
qc := newMockQueryCoordShowCollectionsInterface()
|
||||
qc.addCollection(collectionID, 100)
|
||||
qc := NewQueryCoordMock()
|
||||
qc.Start()
|
||||
defer qc.Stop()
|
||||
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_LoadCollection,
|
||||
MsgID: 0,
|
||||
Timestamp: 0,
|
||||
SourceID: Params.ProxyID,
|
||||
},
|
||||
DbID: 0,
|
||||
CollectionID: collectionID,
|
||||
Schema: nil,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
||||
|
||||
req := constructSearchRequest(dbName, collectionName,
|
||||
expr,
|
||||
@ -1604,3 +1621,340 @@ func TestSearchTask_all(t *testing.T) {
|
||||
cancel()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestSearchTask_Type(t *testing.T) {
|
||||
Params.Init()
|
||||
|
||||
task := &searchTask{
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: nil,
|
||||
},
|
||||
}
|
||||
assert.NoError(t, task.OnEnqueue())
|
||||
assert.Equal(t, commonpb.MsgType_Search, task.Type())
|
||||
}
|
||||
|
||||
func TestSearchTask_Ts(t *testing.T) {
|
||||
Params.Init()
|
||||
|
||||
task := &searchTask{
|
||||
SearchRequest: &internalpb.SearchRequest{
|
||||
Base: nil,
|
||||
},
|
||||
}
|
||||
assert.NoError(t, task.OnEnqueue())
|
||||
|
||||
ts := Timestamp(time.Now().Nanosecond())
|
||||
task.SetTs(ts)
|
||||
assert.Equal(t, ts, task.BeginTs())
|
||||
assert.Equal(t, ts, task.EndTs())
|
||||
}
|
||||
|
||||
func TestSearchTask_Channels(t *testing.T) {
|
||||
var err error
|
||||
|
||||
Params.Init()
|
||||
|
||||
rc := NewRootCoordMock()
|
||||
rc.Start()
|
||||
defer rc.Stop()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err = InitMetaCache(rc)
|
||||
assert.NoError(t, err)
|
||||
|
||||
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
|
||||
query := newMockGetChannelsService()
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
|
||||
defer chMgr.removeAllDMLStream()
|
||||
defer chMgr.removeAllDQLStream()
|
||||
|
||||
prefix := "TestSearchTask_Channels"
|
||||
collectionName := prefix + funcutil.GenRandomStr()
|
||||
shardsNum := int32(2)
|
||||
dbName := ""
|
||||
int64Field := "int64"
|
||||
floatVecField := "fvec"
|
||||
dim := 128
|
||||
|
||||
task := &searchTask{
|
||||
ctx: ctx,
|
||||
query: &milvuspb.SearchRequest{
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
chMgr: chMgr,
|
||||
}
|
||||
|
||||
// collection not exist
|
||||
_, err = task.getVChannels()
|
||||
assert.Error(t, err)
|
||||
_, err = task.getVChannels()
|
||||
assert.Error(t, err)
|
||||
|
||||
schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
createColT := &createCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: shardsNum,
|
||||
},
|
||||
ctx: ctx,
|
||||
rootCoord: rc,
|
||||
result: nil,
|
||||
schema: nil,
|
||||
}
|
||||
|
||||
assert.NoError(t, createColT.OnEnqueue())
|
||||
assert.NoError(t, createColT.PreExecute(ctx))
|
||||
assert.NoError(t, createColT.Execute(ctx))
|
||||
assert.NoError(t, createColT.PostExecute(ctx))
|
||||
|
||||
_, err = task.getChannels()
|
||||
assert.NoError(t, err)
|
||||
_, err = task.getVChannels()
|
||||
assert.NoError(t, err)
|
||||
|
||||
_ = chMgr.removeAllDMLStream()
|
||||
chMgr.dmlChannelsMgr.getChannelsFunc = func(collectionID UniqueID) (map[vChan]pChan, error) {
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
_, err = task.getChannels()
|
||||
assert.Error(t, err)
|
||||
_, err = task.getVChannels()
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSearchTask_PreExecute(t *testing.T) {
|
||||
var err error
|
||||
|
||||
Params.Init()
|
||||
Params.SearchResultChannelNames = []string{funcutil.GenRandomStr()}
|
||||
|
||||
rc := NewRootCoordMock()
|
||||
rc.Start()
|
||||
defer rc.Stop()
|
||||
|
||||
qc := NewQueryCoordMock()
|
||||
qc.Start()
|
||||
defer qc.Stop()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err = InitMetaCache(rc)
|
||||
assert.NoError(t, err)
|
||||
|
||||
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
|
||||
query := newMockGetChannelsService()
|
||||
factory := newSimpleMockMsgStreamFactory()
|
||||
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
|
||||
defer chMgr.removeAllDMLStream()
|
||||
defer chMgr.removeAllDQLStream()
|
||||
|
||||
prefix := "TestSearchTask_PreExecute"
|
||||
collectionName := prefix + funcutil.GenRandomStr()
|
||||
shardsNum := int32(2)
|
||||
dbName := ""
|
||||
int64Field := "int64"
|
||||
floatVecField := "fvec"
|
||||
dim := 128
|
||||
|
||||
task := &searchTask{
|
||||
ctx: ctx,
|
||||
SearchRequest: &internalpb.SearchRequest{},
|
||||
query: &milvuspb.SearchRequest{
|
||||
CollectionName: collectionName,
|
||||
},
|
||||
chMgr: chMgr,
|
||||
qc: qc,
|
||||
}
|
||||
assert.NoError(t, task.OnEnqueue())
|
||||
|
||||
// collection not exist
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
|
||||
schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
|
||||
marshaledSchema, err := proto.Marshal(schema)
|
||||
assert.NoError(t, err)
|
||||
|
||||
createColT := &createCollectionTask{
|
||||
Condition: NewTaskCondition(ctx),
|
||||
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||
Base: nil,
|
||||
DbName: dbName,
|
||||
CollectionName: collectionName,
|
||||
Schema: marshaledSchema,
|
||||
ShardsNum: shardsNum,
|
||||
},
|
||||
ctx: ctx,
|
||||
rootCoord: rc,
|
||||
result: nil,
|
||||
schema: nil,
|
||||
}
|
||||
|
||||
assert.NoError(t, createColT.OnEnqueue())
|
||||
assert.NoError(t, createColT.PreExecute(ctx))
|
||||
assert.NoError(t, createColT.Execute(ctx))
|
||||
assert.NoError(t, createColT.PostExecute(ctx))
|
||||
|
||||
collectionID, _ := globalMetaCache.GetCollectionID(ctx, collectionName)
|
||||
|
||||
// ValidateCollectionName
|
||||
task.query.CollectionName = "$"
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
task.query.CollectionName = collectionName
|
||||
|
||||
// Validate Partition
|
||||
task.query.PartitionNames = []string{"$"}
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
task.query.PartitionNames = nil
|
||||
|
||||
// mock show collections of query coord
|
||||
qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||
return nil, errors.New("mock")
|
||||
})
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||
return &querypb.ShowCollectionsResponse{
|
||||
Status: &commonpb.Status{
|
||||
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||
Reason: "mock",
|
||||
},
|
||||
}, nil
|
||||
})
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
qc.ResetShowCollectionsFunc()
|
||||
|
||||
// collection not loaded
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
_, _ = qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
|
||||
Base: &commonpb.MsgBase{
|
||||
MsgType: commonpb.MsgType_LoadCollection,
|
||||
MsgID: 0,
|
||||
Timestamp: 0,
|
||||
SourceID: 0,
|
||||
},
|
||||
DbID: 0,
|
||||
CollectionID: collectionID,
|
||||
Schema: nil,
|
||||
})
|
||||
|
||||
// no anns field
|
||||
task.query.DslType = commonpb.DslType_BoolExprV1
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
task.query.SearchParams = []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: AnnsFieldKey,
|
||||
Value: floatVecField,
|
||||
},
|
||||
}
|
||||
|
||||
// no topk
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
task.query.SearchParams = []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: AnnsFieldKey,
|
||||
Value: floatVecField,
|
||||
},
|
||||
{
|
||||
Key: TopKKey,
|
||||
Value: "invalid",
|
||||
},
|
||||
}
|
||||
|
||||
// invalid topk
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
task.query.SearchParams = []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: AnnsFieldKey,
|
||||
Value: floatVecField,
|
||||
},
|
||||
{
|
||||
Key: TopKKey,
|
||||
Value: "10",
|
||||
},
|
||||
}
|
||||
|
||||
// no metric type
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
task.query.SearchParams = []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: AnnsFieldKey,
|
||||
Value: floatVecField,
|
||||
},
|
||||
{
|
||||
Key: TopKKey,
|
||||
Value: "10",
|
||||
},
|
||||
{
|
||||
Key: MetricTypeKey,
|
||||
Value: distance.L2,
|
||||
},
|
||||
}
|
||||
|
||||
// no search params
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
task.query.SearchParams = []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: AnnsFieldKey,
|
||||
Value: int64Field,
|
||||
},
|
||||
{
|
||||
Key: TopKKey,
|
||||
Value: "10",
|
||||
},
|
||||
{
|
||||
Key: MetricTypeKey,
|
||||
Value: distance.L2,
|
||||
},
|
||||
{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10}`,
|
||||
},
|
||||
}
|
||||
|
||||
// failed to create query plan
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
task.query.SearchParams = []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: AnnsFieldKey,
|
||||
Value: floatVecField,
|
||||
},
|
||||
{
|
||||
Key: TopKKey,
|
||||
Value: "10",
|
||||
},
|
||||
{
|
||||
Key: MetricTypeKey,
|
||||
Value: distance.L2,
|
||||
},
|
||||
{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10}`,
|
||||
},
|
||||
}
|
||||
|
||||
// field not exist
|
||||
task.query.OutputFields = []string{int64Field + funcutil.GenRandomStr()}
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
// contain vector field
|
||||
task.query.OutputFields = []string{floatVecField}
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
task.query.OutputFields = []string{int64Field}
|
||||
|
||||
// partition
|
||||
rc.showPartitionsFunc = func(ctx context.Context, request *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) {
|
||||
return nil, errors.New("mock")
|
||||
}
|
||||
assert.Error(t, task.PreExecute(ctx))
|
||||
rc.showPartitionsFunc = nil
|
||||
|
||||
// TODO(dragondriver): test partition-related error
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user