mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-08 10:08:42 +08:00
Coverage error case of searchTask.PreExecute (#7826)
Signed-off-by: dragondriver <jiquan.long@zilliz.com>
This commit is contained in:
parent
664d4ca923
commit
bcf9b4e240
@ -41,11 +41,6 @@ type getChannelsService interface {
|
|||||||
GetChannels(collectionID UniqueID) (map[vChan]pChan, error)
|
GetChannels(collectionID UniqueID) (map[vChan]pChan, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// queryCoordShowCollectionsInterface used in searchTask & queryTask
|
|
||||||
type queryCoordShowCollectionsInterface interface {
|
|
||||||
ShowCollections(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// createQueryChannelInterface defines CreateQueryChannel
|
// createQueryChannelInterface defines CreateQueryChannel
|
||||||
type createQueryChannelInterface interface {
|
type createQueryChannelInterface interface {
|
||||||
CreateQueryChannel(ctx context.Context, request *querypb.CreateQueryChannelRequest) (*querypb.CreateQueryChannelResponse, error)
|
CreateQueryChannel(ctx context.Context, request *querypb.CreateQueryChannelRequest) (*querypb.CreateQueryChannelResponse, error)
|
||||||
|
|||||||
@ -17,8 +17,6 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus/internal/proto/querypb"
|
|
||||||
|
|
||||||
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
"github.com/milvus-io/milvus/internal/proto/schemapb"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus/internal/msgstream"
|
"github.com/milvus-io/milvus/internal/msgstream"
|
||||||
@ -555,33 +553,3 @@ func generateHashKeys(numRows int) []uint32 {
|
|||||||
}
|
}
|
||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockQueryCoordShowCollectionsInterface struct {
|
|
||||||
collectionIDs []int64
|
|
||||||
inMemoryPercentages []int64
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ins *mockQueryCoordShowCollectionsInterface) ShowCollections(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
|
||||||
resp := &querypb.ShowCollectionsResponse{
|
|
||||||
Status: &commonpb.Status{
|
|
||||||
ErrorCode: commonpb.ErrorCode_Success,
|
|
||||||
Reason: "",
|
|
||||||
},
|
|
||||||
CollectionIDs: ins.collectionIDs,
|
|
||||||
InMemoryPercentages: ins.inMemoryPercentages,
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ins *mockQueryCoordShowCollectionsInterface) addCollection(collectionID int64, inMemoryPercentage int64) {
|
|
||||||
ins.collectionIDs = append(ins.collectionIDs, collectionID)
|
|
||||||
ins.inMemoryPercentages = append(ins.inMemoryPercentages, collectionID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func newMockQueryCoordShowCollectionsInterface() *mockQueryCoordShowCollectionsInterface {
|
|
||||||
return &mockQueryCoordShowCollectionsInterface{
|
|
||||||
collectionIDs: make([]int64, 0),
|
|
||||||
inMemoryPercentages: make([]int64, 0),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
312
internal/proxy/query_coord_mock_test.go
Normal file
312
internal/proxy/query_coord_mock_test.go
Normal file
@ -0,0 +1,312 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus/internal/util/funcutil"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/uniquegenerator"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus/internal/proto/commonpb"
|
||||||
|
"github.com/milvus-io/milvus/internal/proto/internalpb"
|
||||||
|
"github.com/milvus-io/milvus/internal/proto/milvuspb"
|
||||||
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||||
|
"github.com/milvus-io/milvus/internal/util/typeutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type QueryCoordMockOption func(mock *QueryCoordMock)
|
||||||
|
|
||||||
|
type queryCoordShowCollectionsFuncType func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error)
|
||||||
|
|
||||||
|
func SetQueryCoordShowCollectionsFunc(f queryCoordShowCollectionsFuncType) QueryCoordMockOption {
|
||||||
|
return func(mock *QueryCoordMock) {
|
||||||
|
mock.showCollectionsFunc = f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type QueryCoordMock struct {
|
||||||
|
nodeID typeutil.UniqueID
|
||||||
|
address string
|
||||||
|
|
||||||
|
state atomic.Value // internal.StateCode
|
||||||
|
|
||||||
|
collectionIDs []int64
|
||||||
|
inMemoryPercentages []int64
|
||||||
|
colMtx sync.RWMutex
|
||||||
|
|
||||||
|
showCollectionsFunc queryCoordShowCollectionsFuncType
|
||||||
|
|
||||||
|
statisticsChannel string
|
||||||
|
timeTickChannel string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) updateState(state internalpb.StateCode) {
|
||||||
|
coord.state.Store(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) getState() internalpb.StateCode {
|
||||||
|
return coord.state.Load().(internalpb.StateCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) healthy() bool {
|
||||||
|
return coord.getState() == internalpb.StateCode_Healthy
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) Init() error {
|
||||||
|
coord.updateState(internalpb.StateCode_Initializing)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) Start() error {
|
||||||
|
defer coord.updateState(internalpb.StateCode_Healthy)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) Stop() error {
|
||||||
|
defer coord.updateState(internalpb.StateCode_Abnormal)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
|
||||||
|
return &internalpb.ComponentStates{
|
||||||
|
State: &internalpb.ComponentInfo{
|
||||||
|
NodeID: coord.nodeID,
|
||||||
|
Role: typeutil.QueryCoordRole,
|
||||||
|
StateCode: coord.getState(),
|
||||||
|
ExtraInfo: nil,
|
||||||
|
},
|
||||||
|
SubcomponentStates: nil,
|
||||||
|
Status: &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_Success,
|
||||||
|
Reason: "",
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
|
||||||
|
return &milvuspb.StringResponse{
|
||||||
|
Status: &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_Success,
|
||||||
|
Reason: "",
|
||||||
|
},
|
||||||
|
Value: coord.statisticsChannel,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) Register() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
|
||||||
|
return &milvuspb.StringResponse{
|
||||||
|
Status: &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_Success,
|
||||||
|
Reason: "",
|
||||||
|
},
|
||||||
|
Value: coord.timeTickChannel,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) ResetShowCollectionsFunc() {
|
||||||
|
coord.showCollectionsFunc = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) SetShowCollectionsFunc(f queryCoordShowCollectionsFuncType) {
|
||||||
|
coord.showCollectionsFunc = f
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||||
|
if !coord.healthy() {
|
||||||
|
return &querypb.ShowCollectionsResponse{
|
||||||
|
Status: &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
|
Reason: "unhealthy",
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if coord.showCollectionsFunc != nil {
|
||||||
|
return coord.showCollectionsFunc(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
coord.colMtx.RLock()
|
||||||
|
defer coord.colMtx.RUnlock()
|
||||||
|
|
||||||
|
resp := &querypb.ShowCollectionsResponse{
|
||||||
|
Status: &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_Success,
|
||||||
|
Reason: "",
|
||||||
|
},
|
||||||
|
CollectionIDs: coord.collectionIDs,
|
||||||
|
InMemoryPercentages: coord.inMemoryPercentages,
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) {
|
||||||
|
if !coord.healthy() {
|
||||||
|
return &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
|
Reason: "unhealthy",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
coord.colMtx.Lock()
|
||||||
|
defer coord.colMtx.Unlock()
|
||||||
|
|
||||||
|
for _, colID := range coord.collectionIDs {
|
||||||
|
if req.CollectionID == colID {
|
||||||
|
return &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
|
Reason: fmt.Sprintf("collection %v already loaded", req.CollectionID),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
coord.collectionIDs = append(coord.collectionIDs, req.CollectionID)
|
||||||
|
coord.inMemoryPercentages = append(coord.inMemoryPercentages, 100)
|
||||||
|
|
||||||
|
return &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_Success,
|
||||||
|
Reason: "",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
|
||||||
|
if !coord.healthy() {
|
||||||
|
return &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
|
Reason: "unhealthy",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
coord.colMtx.Lock()
|
||||||
|
defer coord.colMtx.Unlock()
|
||||||
|
|
||||||
|
for i := len(coord.collectionIDs) - 1; i >= 0; i-- {
|
||||||
|
if req.CollectionID == coord.collectionIDs[i] {
|
||||||
|
coord.collectionIDs = append(coord.collectionIDs[:i], coord.collectionIDs[i+1:]...)
|
||||||
|
coord.inMemoryPercentages = append(coord.inMemoryPercentages[:i], coord.inMemoryPercentages[i+1:]...)
|
||||||
|
|
||||||
|
return &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_Success,
|
||||||
|
Reason: "",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
|
Reason: fmt.Sprintf("collection %v not loaded", req.CollectionID),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
|
||||||
|
if !coord.healthy() {
|
||||||
|
return &querypb.ShowPartitionsResponse{
|
||||||
|
Status: &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
|
Reason: "unhealthy",
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) {
|
||||||
|
if !coord.healthy() {
|
||||||
|
return &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
|
Reason: "unhealthy",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
|
||||||
|
if !coord.healthy() {
|
||||||
|
return &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
|
Reason: "unhealthy",
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) CreateQueryChannel(ctx context.Context, req *querypb.CreateQueryChannelRequest) (*querypb.CreateQueryChannelResponse, error) {
|
||||||
|
if !coord.healthy() {
|
||||||
|
return &querypb.CreateQueryChannelResponse{
|
||||||
|
Status: &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
|
Reason: "unhealthy",
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) {
|
||||||
|
if !coord.healthy() {
|
||||||
|
return &querypb.GetPartitionStatesResponse{
|
||||||
|
Status: &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
|
Reason: "unhealthy",
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
|
||||||
|
if !coord.healthy() {
|
||||||
|
return &querypb.GetSegmentInfoResponse{
|
||||||
|
Status: &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
|
Reason: "unhealthy",
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (coord *QueryCoordMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
|
||||||
|
if !coord.healthy() {
|
||||||
|
return &milvuspb.GetMetricsResponse{
|
||||||
|
Status: &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
|
Reason: "unhealthy",
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewQueryCoordMock(opts ...QueryCoordMockOption) *QueryCoordMock {
|
||||||
|
coord := &QueryCoordMock{
|
||||||
|
nodeID: UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()),
|
||||||
|
address: funcutil.GenRandomStr(), // TODO(dragondriver): random address
|
||||||
|
state: atomic.Value{},
|
||||||
|
collectionIDs: make([]int64, 0),
|
||||||
|
inMemoryPercentages: make([]int64, 0),
|
||||||
|
colMtx: sync.RWMutex{},
|
||||||
|
statisticsChannel: funcutil.GenRandomStr(),
|
||||||
|
timeTickChannel: funcutil.GenRandomStr(),
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(coord)
|
||||||
|
}
|
||||||
|
|
||||||
|
return coord
|
||||||
|
}
|
||||||
@ -73,11 +73,7 @@ type RootCoordMockOption func(mock *RootCoordMock)
|
|||||||
|
|
||||||
type describeCollectionFuncType func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error)
|
type describeCollectionFuncType func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error)
|
||||||
|
|
||||||
func SetDescribeCollectionFunc(f describeCollectionFuncType) RootCoordMockOption {
|
type showPartitionsFuncType func(ctx context.Context, request *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error)
|
||||||
return func(mock *RootCoordMock) {
|
|
||||||
mock.SetDescribeCollectionFunc(f)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type RootCoordMock struct {
|
type RootCoordMock struct {
|
||||||
nodeID typeutil.UniqueID
|
nodeID typeutil.UniqueID
|
||||||
@ -98,6 +94,7 @@ type RootCoordMock struct {
|
|||||||
partitionMtx sync.RWMutex
|
partitionMtx sync.RWMutex
|
||||||
|
|
||||||
describeCollectionFunc describeCollectionFuncType
|
describeCollectionFunc describeCollectionFuncType
|
||||||
|
showPartitionsFunc showPartitionsFuncType
|
||||||
|
|
||||||
// TODO(dragondriver): index-related
|
// TODO(dragondriver): index-related
|
||||||
|
|
||||||
@ -332,7 +329,7 @@ func (coord *RootCoordMock) SetDescribeCollectionFunc(f describeCollectionFuncTy
|
|||||||
coord.describeCollectionFunc = f
|
coord.describeCollectionFunc = f
|
||||||
}
|
}
|
||||||
|
|
||||||
func (coord *RootCoordMock) ResetDescribeCollectionFunc(f describeCollectionFuncType) {
|
func (coord *RootCoordMock) ResetDescribeCollectionFunc() {
|
||||||
coord.describeCollectionFunc = nil
|
coord.describeCollectionFunc = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -560,6 +557,11 @@ func (coord *RootCoordMock) ShowPartitions(ctx context.Context, req *milvuspb.Sh
|
|||||||
PartitionIDs: nil,
|
PartitionIDs: nil,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if coord.showPartitionsFunc != nil {
|
||||||
|
return coord.showPartitionsFunc(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
coord.collMtx.RLock()
|
coord.collMtx.RLock()
|
||||||
defer coord.collMtx.RUnlock()
|
defer coord.collMtx.RUnlock()
|
||||||
|
|
||||||
|
|||||||
@ -1350,7 +1350,7 @@ type searchTask struct {
|
|||||||
result *milvuspb.SearchResults
|
result *milvuspb.SearchResults
|
||||||
query *milvuspb.SearchRequest
|
query *milvuspb.SearchRequest
|
||||||
chMgr channelsMgr
|
chMgr channelsMgr
|
||||||
qc queryCoordShowCollectionsInterface
|
qc types.QueryCoord
|
||||||
}
|
}
|
||||||
|
|
||||||
func (st *searchTask) TraceCtx() context.Context {
|
func (st *searchTask) TraceCtx() context.Context {
|
||||||
@ -1398,6 +1398,14 @@ func (st *searchTask) getChannels() ([]pChan, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = st.chMgr.getChannels(collID)
|
||||||
|
if err != nil {
|
||||||
|
err := st.chMgr.createDMLMsgStream(collID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return st.chMgr.getChannels(collID)
|
return st.chMgr.getChannels(collID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1474,10 +1482,7 @@ func (st *searchTask) PreExecute(ctx context.Context) error {
|
|||||||
|
|
||||||
st.Base.MsgType = commonpb.MsgType_Search
|
st.Base.MsgType = commonpb.MsgType_Search
|
||||||
|
|
||||||
schema, err := globalMetaCache.GetCollectionSchema(ctx, collectionName)
|
schema, _ := globalMetaCache.GetCollectionSchema(ctx, collectionName)
|
||||||
if err != nil { // err is not nil if collection not exists
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
outputFields, err := translateOutputFields(st.query.OutputFields, schema, false)
|
outputFields, err := translateOutputFields(st.query.OutputFields, schema, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -1575,11 +1580,7 @@ func (st *searchTask) PreExecute(ctx context.Context) error {
|
|||||||
|
|
||||||
st.SearchRequest.ResultChannelID = Params.SearchResultChannelNames[0]
|
st.SearchRequest.ResultChannelID = Params.SearchResultChannelNames[0]
|
||||||
st.SearchRequest.DbID = 0 // todo
|
st.SearchRequest.DbID = 0 // todo
|
||||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
st.SearchRequest.CollectionID = collID
|
||||||
if err != nil { // err is not nil if collection not exists
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
st.SearchRequest.CollectionID = collectionID
|
|
||||||
st.SearchRequest.PartitionIDs = make([]UniqueID, 0)
|
st.SearchRequest.PartitionIDs = make([]UniqueID, 0)
|
||||||
|
|
||||||
partitionsMap, err := globalMetaCache.GetPartitions(ctx, collectionName)
|
partitionsMap, err := globalMetaCache.GetPartitions(ctx, collectionName)
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -12,6 +13,8 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/milvus-io/milvus/internal/proto/querypb"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus/internal/msgstream"
|
"github.com/milvus-io/milvus/internal/msgstream"
|
||||||
|
|
||||||
"github.com/milvus-io/milvus/internal/util/distance"
|
"github.com/milvus-io/milvus/internal/util/distance"
|
||||||
@ -1467,8 +1470,22 @@ func TestSearchTask_all(t *testing.T) {
|
|||||||
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
qc := newMockQueryCoordShowCollectionsInterface()
|
qc := NewQueryCoordMock()
|
||||||
qc.addCollection(collectionID, 100)
|
qc.Start()
|
||||||
|
defer qc.Stop()
|
||||||
|
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
|
||||||
|
Base: &commonpb.MsgBase{
|
||||||
|
MsgType: commonpb.MsgType_LoadCollection,
|
||||||
|
MsgID: 0,
|
||||||
|
Timestamp: 0,
|
||||||
|
SourceID: Params.ProxyID,
|
||||||
|
},
|
||||||
|
DbID: 0,
|
||||||
|
CollectionID: collectionID,
|
||||||
|
Schema: nil,
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
|
||||||
|
|
||||||
req := constructSearchRequest(dbName, collectionName,
|
req := constructSearchRequest(dbName, collectionName,
|
||||||
expr,
|
expr,
|
||||||
@ -1604,3 +1621,340 @@ func TestSearchTask_all(t *testing.T) {
|
|||||||
cancel()
|
cancel()
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSearchTask_Type(t *testing.T) {
|
||||||
|
Params.Init()
|
||||||
|
|
||||||
|
task := &searchTask{
|
||||||
|
SearchRequest: &internalpb.SearchRequest{
|
||||||
|
Base: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.NoError(t, task.OnEnqueue())
|
||||||
|
assert.Equal(t, commonpb.MsgType_Search, task.Type())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSearchTask_Ts(t *testing.T) {
|
||||||
|
Params.Init()
|
||||||
|
|
||||||
|
task := &searchTask{
|
||||||
|
SearchRequest: &internalpb.SearchRequest{
|
||||||
|
Base: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert.NoError(t, task.OnEnqueue())
|
||||||
|
|
||||||
|
ts := Timestamp(time.Now().Nanosecond())
|
||||||
|
task.SetTs(ts)
|
||||||
|
assert.Equal(t, ts, task.BeginTs())
|
||||||
|
assert.Equal(t, ts, task.EndTs())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSearchTask_Channels(t *testing.T) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
Params.Init()
|
||||||
|
|
||||||
|
rc := NewRootCoordMock()
|
||||||
|
rc.Start()
|
||||||
|
defer rc.Stop()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
err = InitMetaCache(rc)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
|
||||||
|
query := newMockGetChannelsService()
|
||||||
|
factory := newSimpleMockMsgStreamFactory()
|
||||||
|
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
|
||||||
|
defer chMgr.removeAllDMLStream()
|
||||||
|
defer chMgr.removeAllDQLStream()
|
||||||
|
|
||||||
|
prefix := "TestSearchTask_Channels"
|
||||||
|
collectionName := prefix + funcutil.GenRandomStr()
|
||||||
|
shardsNum := int32(2)
|
||||||
|
dbName := ""
|
||||||
|
int64Field := "int64"
|
||||||
|
floatVecField := "fvec"
|
||||||
|
dim := 128
|
||||||
|
|
||||||
|
task := &searchTask{
|
||||||
|
ctx: ctx,
|
||||||
|
query: &milvuspb.SearchRequest{
|
||||||
|
CollectionName: collectionName,
|
||||||
|
},
|
||||||
|
chMgr: chMgr,
|
||||||
|
}
|
||||||
|
|
||||||
|
// collection not exist
|
||||||
|
_, err = task.getVChannels()
|
||||||
|
assert.Error(t, err)
|
||||||
|
_, err = task.getVChannels()
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
|
||||||
|
marshaledSchema, err := proto.Marshal(schema)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
createColT := &createCollectionTask{
|
||||||
|
Condition: NewTaskCondition(ctx),
|
||||||
|
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||||
|
Base: nil,
|
||||||
|
DbName: dbName,
|
||||||
|
CollectionName: collectionName,
|
||||||
|
Schema: marshaledSchema,
|
||||||
|
ShardsNum: shardsNum,
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
rootCoord: rc,
|
||||||
|
result: nil,
|
||||||
|
schema: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, createColT.OnEnqueue())
|
||||||
|
assert.NoError(t, createColT.PreExecute(ctx))
|
||||||
|
assert.NoError(t, createColT.Execute(ctx))
|
||||||
|
assert.NoError(t, createColT.PostExecute(ctx))
|
||||||
|
|
||||||
|
_, err = task.getChannels()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
_, err = task.getVChannels()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
_ = chMgr.removeAllDMLStream()
|
||||||
|
chMgr.dmlChannelsMgr.getChannelsFunc = func(collectionID UniqueID) (map[vChan]pChan, error) {
|
||||||
|
return nil, errors.New("mock")
|
||||||
|
}
|
||||||
|
_, err = task.getChannels()
|
||||||
|
assert.Error(t, err)
|
||||||
|
_, err = task.getVChannels()
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSearchTask_PreExecute(t *testing.T) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
Params.Init()
|
||||||
|
Params.SearchResultChannelNames = []string{funcutil.GenRandomStr()}
|
||||||
|
|
||||||
|
rc := NewRootCoordMock()
|
||||||
|
rc.Start()
|
||||||
|
defer rc.Stop()
|
||||||
|
|
||||||
|
qc := NewQueryCoordMock()
|
||||||
|
qc.Start()
|
||||||
|
defer qc.Stop()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
err = InitMetaCache(rc)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
|
||||||
|
query := newMockGetChannelsService()
|
||||||
|
factory := newSimpleMockMsgStreamFactory()
|
||||||
|
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
|
||||||
|
defer chMgr.removeAllDMLStream()
|
||||||
|
defer chMgr.removeAllDQLStream()
|
||||||
|
|
||||||
|
prefix := "TestSearchTask_PreExecute"
|
||||||
|
collectionName := prefix + funcutil.GenRandomStr()
|
||||||
|
shardsNum := int32(2)
|
||||||
|
dbName := ""
|
||||||
|
int64Field := "int64"
|
||||||
|
floatVecField := "fvec"
|
||||||
|
dim := 128
|
||||||
|
|
||||||
|
task := &searchTask{
|
||||||
|
ctx: ctx,
|
||||||
|
SearchRequest: &internalpb.SearchRequest{},
|
||||||
|
query: &milvuspb.SearchRequest{
|
||||||
|
CollectionName: collectionName,
|
||||||
|
},
|
||||||
|
chMgr: chMgr,
|
||||||
|
qc: qc,
|
||||||
|
}
|
||||||
|
assert.NoError(t, task.OnEnqueue())
|
||||||
|
|
||||||
|
// collection not exist
|
||||||
|
assert.Error(t, task.PreExecute(ctx))
|
||||||
|
|
||||||
|
schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName)
|
||||||
|
marshaledSchema, err := proto.Marshal(schema)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
createColT := &createCollectionTask{
|
||||||
|
Condition: NewTaskCondition(ctx),
|
||||||
|
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
|
||||||
|
Base: nil,
|
||||||
|
DbName: dbName,
|
||||||
|
CollectionName: collectionName,
|
||||||
|
Schema: marshaledSchema,
|
||||||
|
ShardsNum: shardsNum,
|
||||||
|
},
|
||||||
|
ctx: ctx,
|
||||||
|
rootCoord: rc,
|
||||||
|
result: nil,
|
||||||
|
schema: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, createColT.OnEnqueue())
|
||||||
|
assert.NoError(t, createColT.PreExecute(ctx))
|
||||||
|
assert.NoError(t, createColT.Execute(ctx))
|
||||||
|
assert.NoError(t, createColT.PostExecute(ctx))
|
||||||
|
|
||||||
|
collectionID, _ := globalMetaCache.GetCollectionID(ctx, collectionName)
|
||||||
|
|
||||||
|
// ValidateCollectionName
|
||||||
|
task.query.CollectionName = "$"
|
||||||
|
assert.Error(t, task.PreExecute(ctx))
|
||||||
|
task.query.CollectionName = collectionName
|
||||||
|
|
||||||
|
// Validate Partition
|
||||||
|
task.query.PartitionNames = []string{"$"}
|
||||||
|
assert.Error(t, task.PreExecute(ctx))
|
||||||
|
task.query.PartitionNames = nil
|
||||||
|
|
||||||
|
// mock show collections of query coord
|
||||||
|
qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||||
|
return nil, errors.New("mock")
|
||||||
|
})
|
||||||
|
assert.Error(t, task.PreExecute(ctx))
|
||||||
|
qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) {
|
||||||
|
return &querypb.ShowCollectionsResponse{
|
||||||
|
Status: &commonpb.Status{
|
||||||
|
ErrorCode: commonpb.ErrorCode_UnexpectedError,
|
||||||
|
Reason: "mock",
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
assert.Error(t, task.PreExecute(ctx))
|
||||||
|
qc.ResetShowCollectionsFunc()
|
||||||
|
|
||||||
|
// collection not loaded
|
||||||
|
assert.Error(t, task.PreExecute(ctx))
|
||||||
|
_, _ = qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
|
||||||
|
Base: &commonpb.MsgBase{
|
||||||
|
MsgType: commonpb.MsgType_LoadCollection,
|
||||||
|
MsgID: 0,
|
||||||
|
Timestamp: 0,
|
||||||
|
SourceID: 0,
|
||||||
|
},
|
||||||
|
DbID: 0,
|
||||||
|
CollectionID: collectionID,
|
||||||
|
Schema: nil,
|
||||||
|
})
|
||||||
|
|
||||||
|
// no anns field
|
||||||
|
task.query.DslType = commonpb.DslType_BoolExprV1
|
||||||
|
assert.Error(t, task.PreExecute(ctx))
|
||||||
|
task.query.SearchParams = []*commonpb.KeyValuePair{
|
||||||
|
{
|
||||||
|
Key: AnnsFieldKey,
|
||||||
|
Value: floatVecField,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// no topk
|
||||||
|
assert.Error(t, task.PreExecute(ctx))
|
||||||
|
task.query.SearchParams = []*commonpb.KeyValuePair{
|
||||||
|
{
|
||||||
|
Key: AnnsFieldKey,
|
||||||
|
Value: floatVecField,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: TopKKey,
|
||||||
|
Value: "invalid",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// invalid topk
|
||||||
|
assert.Error(t, task.PreExecute(ctx))
|
||||||
|
task.query.SearchParams = []*commonpb.KeyValuePair{
|
||||||
|
{
|
||||||
|
Key: AnnsFieldKey,
|
||||||
|
Value: floatVecField,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: TopKKey,
|
||||||
|
Value: "10",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// no metric type
|
||||||
|
assert.Error(t, task.PreExecute(ctx))
|
||||||
|
task.query.SearchParams = []*commonpb.KeyValuePair{
|
||||||
|
{
|
||||||
|
Key: AnnsFieldKey,
|
||||||
|
Value: floatVecField,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: TopKKey,
|
||||||
|
Value: "10",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: MetricTypeKey,
|
||||||
|
Value: distance.L2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// no search params
|
||||||
|
assert.Error(t, task.PreExecute(ctx))
|
||||||
|
task.query.SearchParams = []*commonpb.KeyValuePair{
|
||||||
|
{
|
||||||
|
Key: AnnsFieldKey,
|
||||||
|
Value: int64Field,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: TopKKey,
|
||||||
|
Value: "10",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: MetricTypeKey,
|
||||||
|
Value: distance.L2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: SearchParamsKey,
|
||||||
|
Value: `{"nprobe": 10}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// failed to create query plan
|
||||||
|
assert.Error(t, task.PreExecute(ctx))
|
||||||
|
task.query.SearchParams = []*commonpb.KeyValuePair{
|
||||||
|
{
|
||||||
|
Key: AnnsFieldKey,
|
||||||
|
Value: floatVecField,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: TopKKey,
|
||||||
|
Value: "10",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: MetricTypeKey,
|
||||||
|
Value: distance.L2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Key: SearchParamsKey,
|
||||||
|
Value: `{"nprobe": 10}`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// field not exist
|
||||||
|
task.query.OutputFields = []string{int64Field + funcutil.GenRandomStr()}
|
||||||
|
assert.Error(t, task.PreExecute(ctx))
|
||||||
|
// contain vector field
|
||||||
|
task.query.OutputFields = []string{floatVecField}
|
||||||
|
assert.Error(t, task.PreExecute(ctx))
|
||||||
|
task.query.OutputFields = []string{int64Field}
|
||||||
|
|
||||||
|
// partition
|
||||||
|
rc.showPartitionsFunc = func(ctx context.Context, request *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) {
|
||||||
|
return nil, errors.New("mock")
|
||||||
|
}
|
||||||
|
assert.Error(t, task.PreExecute(ctx))
|
||||||
|
rc.showPartitionsFunc = nil
|
||||||
|
|
||||||
|
// TODO(dragondriver): test partition-related error
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user