diff --git a/internal/proxy/interface_def.go b/internal/proxy/interface_def.go index 029e897cfd..586851c1af 100644 --- a/internal/proxy/interface_def.go +++ b/internal/proxy/interface_def.go @@ -41,11 +41,6 @@ type getChannelsService interface { GetChannels(collectionID UniqueID) (map[vChan]pChan, error) } -// queryCoordShowCollectionsInterface used in searchTask & queryTask -type queryCoordShowCollectionsInterface interface { - ShowCollections(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) -} - // createQueryChannelInterface defines CreateQueryChannel type createQueryChannelInterface interface { CreateQueryChannel(ctx context.Context, request *querypb.CreateQueryChannelRequest) (*querypb.CreateQueryChannelResponse, error) diff --git a/internal/proxy/mock_test.go b/internal/proxy/mock_test.go index 0c17e8179d..d7c3a0482f 100644 --- a/internal/proxy/mock_test.go +++ b/internal/proxy/mock_test.go @@ -17,8 +17,6 @@ import ( "sync" "time" - "github.com/milvus-io/milvus/internal/proto/querypb" - "github.com/milvus-io/milvus/internal/proto/schemapb" "github.com/milvus-io/milvus/internal/msgstream" @@ -555,33 +553,3 @@ func generateHashKeys(numRows int) []uint32 { } return ret } - -type mockQueryCoordShowCollectionsInterface struct { - collectionIDs []int64 - inMemoryPercentages []int64 -} - -func (ins *mockQueryCoordShowCollectionsInterface) ShowCollections(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { - resp := &querypb.ShowCollectionsResponse{ - Status: &commonpb.Status{ - ErrorCode: commonpb.ErrorCode_Success, - Reason: "", - }, - CollectionIDs: ins.collectionIDs, - InMemoryPercentages: ins.inMemoryPercentages, - } - - return resp, nil -} - -func (ins *mockQueryCoordShowCollectionsInterface) addCollection(collectionID int64, inMemoryPercentage int64) { - ins.collectionIDs = append(ins.collectionIDs, collectionID) - ins.inMemoryPercentages = append(ins.inMemoryPercentages, collectionID) -} - -func newMockQueryCoordShowCollectionsInterface() *mockQueryCoordShowCollectionsInterface { - return &mockQueryCoordShowCollectionsInterface{ - collectionIDs: make([]int64, 0), - inMemoryPercentages: make([]int64, 0), - } -} diff --git a/internal/proxy/query_coord_mock_test.go b/internal/proxy/query_coord_mock_test.go new file mode 100644 index 0000000000..e080953114 --- /dev/null +++ b/internal/proxy/query_coord_mock_test.go @@ -0,0 +1,312 @@ +package proxy + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + + "github.com/milvus-io/milvus/internal/util/funcutil" + "github.com/milvus-io/milvus/internal/util/uniquegenerator" + + "github.com/milvus-io/milvus/internal/proto/commonpb" + "github.com/milvus-io/milvus/internal/proto/internalpb" + "github.com/milvus-io/milvus/internal/proto/milvuspb" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/util/typeutil" +) + +type QueryCoordMockOption func(mock *QueryCoordMock) + +type queryCoordShowCollectionsFuncType func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) + +func SetQueryCoordShowCollectionsFunc(f queryCoordShowCollectionsFuncType) QueryCoordMockOption { + return func(mock *QueryCoordMock) { + mock.showCollectionsFunc = f + } +} + +type QueryCoordMock struct { + nodeID typeutil.UniqueID + address string + + state atomic.Value // internal.StateCode + + collectionIDs []int64 + inMemoryPercentages []int64 + colMtx sync.RWMutex + + showCollectionsFunc queryCoordShowCollectionsFuncType + + statisticsChannel string + timeTickChannel string +} + +func (coord *QueryCoordMock) updateState(state internalpb.StateCode) { + coord.state.Store(state) +} + +func (coord *QueryCoordMock) getState() internalpb.StateCode { + return coord.state.Load().(internalpb.StateCode) +} + +func (coord *QueryCoordMock) healthy() bool { + return coord.getState() == internalpb.StateCode_Healthy +} + +func (coord *QueryCoordMock) Init() error { + coord.updateState(internalpb.StateCode_Initializing) + return nil +} + +func (coord *QueryCoordMock) Start() error { + defer coord.updateState(internalpb.StateCode_Healthy) + + return nil +} + +func (coord *QueryCoordMock) Stop() error { + defer coord.updateState(internalpb.StateCode_Abnormal) + + return nil +} + +func (coord *QueryCoordMock) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) { + return &internalpb.ComponentStates{ + State: &internalpb.ComponentInfo{ + NodeID: coord.nodeID, + Role: typeutil.QueryCoordRole, + StateCode: coord.getState(), + ExtraInfo: nil, + }, + SubcomponentStates: nil, + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + Reason: "", + }, + }, nil +} + +func (coord *QueryCoordMock) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + Reason: "", + }, + Value: coord.statisticsChannel, + }, nil +} + +func (coord *QueryCoordMock) Register() error { + return nil +} + +func (coord *QueryCoordMock) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) { + return &milvuspb.StringResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + Reason: "", + }, + Value: coord.timeTickChannel, + }, nil +} + +func (coord *QueryCoordMock) ResetShowCollectionsFunc() { + coord.showCollectionsFunc = nil +} + +func (coord *QueryCoordMock) SetShowCollectionsFunc(f queryCoordShowCollectionsFuncType) { + coord.showCollectionsFunc = f +} + +func (coord *QueryCoordMock) ShowCollections(ctx context.Context, req *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { + if !coord.healthy() { + return &querypb.ShowCollectionsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "unhealthy", + }, + }, nil + } + + if coord.showCollectionsFunc != nil { + return coord.showCollectionsFunc(ctx, req) + } + + coord.colMtx.RLock() + defer coord.colMtx.RUnlock() + + resp := &querypb.ShowCollectionsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + Reason: "", + }, + CollectionIDs: coord.collectionIDs, + InMemoryPercentages: coord.inMemoryPercentages, + } + + return resp, nil +} + +func (coord *QueryCoordMock) LoadCollection(ctx context.Context, req *querypb.LoadCollectionRequest) (*commonpb.Status, error) { + if !coord.healthy() { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "unhealthy", + }, nil + } + + coord.colMtx.Lock() + defer coord.colMtx.Unlock() + + for _, colID := range coord.collectionIDs { + if req.CollectionID == colID { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("collection %v already loaded", req.CollectionID), + }, nil + } + } + + coord.collectionIDs = append(coord.collectionIDs, req.CollectionID) + coord.inMemoryPercentages = append(coord.inMemoryPercentages, 100) + + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + Reason: "", + }, nil +} + +func (coord *QueryCoordMock) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) { + if !coord.healthy() { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "unhealthy", + }, nil + } + + coord.colMtx.Lock() + defer coord.colMtx.Unlock() + + for i := len(coord.collectionIDs) - 1; i >= 0; i-- { + if req.CollectionID == coord.collectionIDs[i] { + coord.collectionIDs = append(coord.collectionIDs[:i], coord.collectionIDs[i+1:]...) + coord.inMemoryPercentages = append(coord.inMemoryPercentages[:i], coord.inMemoryPercentages[i+1:]...) + + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_Success, + Reason: "", + }, nil + } + } + + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: fmt.Sprintf("collection %v not loaded", req.CollectionID), + }, nil +} + +func (coord *QueryCoordMock) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) { + if !coord.healthy() { + return &querypb.ShowPartitionsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "unhealthy", + }, + }, nil + } + + panic("implement me") +} + +func (coord *QueryCoordMock) LoadPartitions(ctx context.Context, req *querypb.LoadPartitionsRequest) (*commonpb.Status, error) { + if !coord.healthy() { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "unhealthy", + }, nil + } + + panic("implement me") +} + +func (coord *QueryCoordMock) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) { + if !coord.healthy() { + return &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "unhealthy", + }, nil + } + + panic("implement me") +} + +func (coord *QueryCoordMock) CreateQueryChannel(ctx context.Context, req *querypb.CreateQueryChannelRequest) (*querypb.CreateQueryChannelResponse, error) { + if !coord.healthy() { + return &querypb.CreateQueryChannelResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "unhealthy", + }, + }, nil + } + + panic("implement me") +} + +func (coord *QueryCoordMock) GetPartitionStates(ctx context.Context, req *querypb.GetPartitionStatesRequest) (*querypb.GetPartitionStatesResponse, error) { + if !coord.healthy() { + return &querypb.GetPartitionStatesResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "unhealthy", + }, + }, nil + } + + panic("implement me") +} + +func (coord *QueryCoordMock) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) { + if !coord.healthy() { + return &querypb.GetSegmentInfoResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "unhealthy", + }, + }, nil + } + + panic("implement me") +} + +func (coord *QueryCoordMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) { + if !coord.healthy() { + return &milvuspb.GetMetricsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "unhealthy", + }, + }, nil + } + panic("implement me") +} + +func NewQueryCoordMock(opts ...QueryCoordMockOption) *QueryCoordMock { + coord := &QueryCoordMock{ + nodeID: UniqueID(uniquegenerator.GetUniqueIntGeneratorIns().GetInt()), + address: funcutil.GenRandomStr(), // TODO(dragondriver): random address + state: atomic.Value{}, + collectionIDs: make([]int64, 0), + inMemoryPercentages: make([]int64, 0), + colMtx: sync.RWMutex{}, + statisticsChannel: funcutil.GenRandomStr(), + timeTickChannel: funcutil.GenRandomStr(), + } + + for _, opt := range opts { + opt(coord) + } + + return coord +} diff --git a/internal/proxy/rootcoord_mock_test.go b/internal/proxy/rootcoord_mock_test.go index 270df7369c..293eb1b605 100644 --- a/internal/proxy/rootcoord_mock_test.go +++ b/internal/proxy/rootcoord_mock_test.go @@ -73,11 +73,7 @@ type RootCoordMockOption func(mock *RootCoordMock) type describeCollectionFuncType func(ctx context.Context, request *milvuspb.DescribeCollectionRequest) (*milvuspb.DescribeCollectionResponse, error) -func SetDescribeCollectionFunc(f describeCollectionFuncType) RootCoordMockOption { - return func(mock *RootCoordMock) { - mock.SetDescribeCollectionFunc(f) - } -} +type showPartitionsFuncType func(ctx context.Context, request *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) type RootCoordMock struct { nodeID typeutil.UniqueID @@ -98,6 +94,7 @@ type RootCoordMock struct { partitionMtx sync.RWMutex describeCollectionFunc describeCollectionFuncType + showPartitionsFunc showPartitionsFuncType // TODO(dragondriver): index-related @@ -332,7 +329,7 @@ func (coord *RootCoordMock) SetDescribeCollectionFunc(f describeCollectionFuncTy coord.describeCollectionFunc = f } -func (coord *RootCoordMock) ResetDescribeCollectionFunc(f describeCollectionFuncType) { +func (coord *RootCoordMock) ResetDescribeCollectionFunc() { coord.describeCollectionFunc = nil } @@ -560,6 +557,11 @@ func (coord *RootCoordMock) ShowPartitions(ctx context.Context, req *milvuspb.Sh PartitionIDs: nil, }, nil } + + if coord.showPartitionsFunc != nil { + return coord.showPartitionsFunc(ctx, req) + } + coord.collMtx.RLock() defer coord.collMtx.RUnlock() diff --git a/internal/proxy/task.go b/internal/proxy/task.go index f7d2abb1ad..e9a0f197c2 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -1350,7 +1350,7 @@ type searchTask struct { result *milvuspb.SearchResults query *milvuspb.SearchRequest chMgr channelsMgr - qc queryCoordShowCollectionsInterface + qc types.QueryCoord } func (st *searchTask) TraceCtx() context.Context { @@ -1398,6 +1398,14 @@ func (st *searchTask) getChannels() ([]pChan, error) { return nil, err } + _, err = st.chMgr.getChannels(collID) + if err != nil { + err := st.chMgr.createDMLMsgStream(collID) + if err != nil { + return nil, err + } + } + return st.chMgr.getChannels(collID) } @@ -1474,10 +1482,7 @@ func (st *searchTask) PreExecute(ctx context.Context) error { st.Base.MsgType = commonpb.MsgType_Search - schema, err := globalMetaCache.GetCollectionSchema(ctx, collectionName) - if err != nil { // err is not nil if collection not exists - return err - } + schema, _ := globalMetaCache.GetCollectionSchema(ctx, collectionName) outputFields, err := translateOutputFields(st.query.OutputFields, schema, false) if err != nil { @@ -1575,11 +1580,7 @@ func (st *searchTask) PreExecute(ctx context.Context) error { st.SearchRequest.ResultChannelID = Params.SearchResultChannelNames[0] st.SearchRequest.DbID = 0 // todo - collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) - if err != nil { // err is not nil if collection not exists - return err - } - st.SearchRequest.CollectionID = collectionID + st.SearchRequest.CollectionID = collID st.SearchRequest.PartitionIDs = make([]UniqueID, 0) partitionsMap, err := globalMetaCache.GetPartitions(ctx, collectionName) diff --git a/internal/proxy/task_test.go b/internal/proxy/task_test.go index ff7affd314..a62c882120 100644 --- a/internal/proxy/task_test.go +++ b/internal/proxy/task_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/binary" "encoding/json" + "errors" "fmt" "math/rand" "strconv" @@ -12,6 +13,8 @@ import ( "testing" "time" + "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/msgstream" "github.com/milvus-io/milvus/internal/util/distance" @@ -1467,8 +1470,22 @@ func TestSearchTask_all(t *testing.T) { collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName) assert.NoError(t, err) - qc := newMockQueryCoordShowCollectionsInterface() - qc.addCollection(collectionID, 100) + qc := NewQueryCoordMock() + qc.Start() + defer qc.Stop() + status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadCollection, + MsgID: 0, + Timestamp: 0, + SourceID: Params.ProxyID, + }, + DbID: 0, + CollectionID: collectionID, + Schema: nil, + }) + assert.NoError(t, err) + assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode) req := constructSearchRequest(dbName, collectionName, expr, @@ -1604,3 +1621,340 @@ func TestSearchTask_all(t *testing.T) { cancel() wg.Wait() } + +func TestSearchTask_Type(t *testing.T) { + Params.Init() + + task := &searchTask{ + SearchRequest: &internalpb.SearchRequest{ + Base: nil, + }, + } + assert.NoError(t, task.OnEnqueue()) + assert.Equal(t, commonpb.MsgType_Search, task.Type()) +} + +func TestSearchTask_Ts(t *testing.T) { + Params.Init() + + task := &searchTask{ + SearchRequest: &internalpb.SearchRequest{ + Base: nil, + }, + } + assert.NoError(t, task.OnEnqueue()) + + ts := Timestamp(time.Now().Nanosecond()) + task.SetTs(ts) + assert.Equal(t, ts, task.BeginTs()) + assert.Equal(t, ts, task.EndTs()) +} + +func TestSearchTask_Channels(t *testing.T) { + var err error + + Params.Init() + + rc := NewRootCoordMock() + rc.Start() + defer rc.Stop() + + ctx := context.Background() + + err = InitMetaCache(rc) + assert.NoError(t, err) + + dmlChannelsFunc := getDmlChannelsFunc(ctx, rc) + query := newMockGetChannelsService() + factory := newSimpleMockMsgStreamFactory() + chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory) + defer chMgr.removeAllDMLStream() + defer chMgr.removeAllDQLStream() + + prefix := "TestSearchTask_Channels" + collectionName := prefix + funcutil.GenRandomStr() + shardsNum := int32(2) + dbName := "" + int64Field := "int64" + floatVecField := "fvec" + dim := 128 + + task := &searchTask{ + ctx: ctx, + query: &milvuspb.SearchRequest{ + CollectionName: collectionName, + }, + chMgr: chMgr, + } + + // collection not exist + _, err = task.getVChannels() + assert.Error(t, err) + _, err = task.getVChannels() + assert.Error(t, err) + + schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName) + marshaledSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + + createColT := &createCollectionTask{ + Condition: NewTaskCondition(ctx), + CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: shardsNum, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + schema: nil, + } + + assert.NoError(t, createColT.OnEnqueue()) + assert.NoError(t, createColT.PreExecute(ctx)) + assert.NoError(t, createColT.Execute(ctx)) + assert.NoError(t, createColT.PostExecute(ctx)) + + _, err = task.getChannels() + assert.NoError(t, err) + _, err = task.getVChannels() + assert.NoError(t, err) + + _ = chMgr.removeAllDMLStream() + chMgr.dmlChannelsMgr.getChannelsFunc = func(collectionID UniqueID) (map[vChan]pChan, error) { + return nil, errors.New("mock") + } + _, err = task.getChannels() + assert.Error(t, err) + _, err = task.getVChannels() + assert.Error(t, err) +} + +func TestSearchTask_PreExecute(t *testing.T) { + var err error + + Params.Init() + Params.SearchResultChannelNames = []string{funcutil.GenRandomStr()} + + rc := NewRootCoordMock() + rc.Start() + defer rc.Stop() + + qc := NewQueryCoordMock() + qc.Start() + defer qc.Stop() + + ctx := context.Background() + + err = InitMetaCache(rc) + assert.NoError(t, err) + + dmlChannelsFunc := getDmlChannelsFunc(ctx, rc) + query := newMockGetChannelsService() + factory := newSimpleMockMsgStreamFactory() + chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory) + defer chMgr.removeAllDMLStream() + defer chMgr.removeAllDQLStream() + + prefix := "TestSearchTask_PreExecute" + collectionName := prefix + funcutil.GenRandomStr() + shardsNum := int32(2) + dbName := "" + int64Field := "int64" + floatVecField := "fvec" + dim := 128 + + task := &searchTask{ + ctx: ctx, + SearchRequest: &internalpb.SearchRequest{}, + query: &milvuspb.SearchRequest{ + CollectionName: collectionName, + }, + chMgr: chMgr, + qc: qc, + } + assert.NoError(t, task.OnEnqueue()) + + // collection not exist + assert.Error(t, task.PreExecute(ctx)) + + schema := constructCollectionSchema(int64Field, floatVecField, dim, collectionName) + marshaledSchema, err := proto.Marshal(schema) + assert.NoError(t, err) + + createColT := &createCollectionTask{ + Condition: NewTaskCondition(ctx), + CreateCollectionRequest: &milvuspb.CreateCollectionRequest{ + Base: nil, + DbName: dbName, + CollectionName: collectionName, + Schema: marshaledSchema, + ShardsNum: shardsNum, + }, + ctx: ctx, + rootCoord: rc, + result: nil, + schema: nil, + } + + assert.NoError(t, createColT.OnEnqueue()) + assert.NoError(t, createColT.PreExecute(ctx)) + assert.NoError(t, createColT.Execute(ctx)) + assert.NoError(t, createColT.PostExecute(ctx)) + + collectionID, _ := globalMetaCache.GetCollectionID(ctx, collectionName) + + // ValidateCollectionName + task.query.CollectionName = "$" + assert.Error(t, task.PreExecute(ctx)) + task.query.CollectionName = collectionName + + // Validate Partition + task.query.PartitionNames = []string{"$"} + assert.Error(t, task.PreExecute(ctx)) + task.query.PartitionNames = nil + + // mock show collections of query coord + qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { + return nil, errors.New("mock") + }) + assert.Error(t, task.PreExecute(ctx)) + qc.SetShowCollectionsFunc(func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error) { + return &querypb.ShowCollectionsResponse{ + Status: &commonpb.Status{ + ErrorCode: commonpb.ErrorCode_UnexpectedError, + Reason: "mock", + }, + }, nil + }) + assert.Error(t, task.PreExecute(ctx)) + qc.ResetShowCollectionsFunc() + + // collection not loaded + assert.Error(t, task.PreExecute(ctx)) + _, _ = qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_LoadCollection, + MsgID: 0, + Timestamp: 0, + SourceID: 0, + }, + DbID: 0, + CollectionID: collectionID, + Schema: nil, + }) + + // no anns field + task.query.DslType = commonpb.DslType_BoolExprV1 + assert.Error(t, task.PreExecute(ctx)) + task.query.SearchParams = []*commonpb.KeyValuePair{ + { + Key: AnnsFieldKey, + Value: floatVecField, + }, + } + + // no topk + assert.Error(t, task.PreExecute(ctx)) + task.query.SearchParams = []*commonpb.KeyValuePair{ + { + Key: AnnsFieldKey, + Value: floatVecField, + }, + { + Key: TopKKey, + Value: "invalid", + }, + } + + // invalid topk + assert.Error(t, task.PreExecute(ctx)) + task.query.SearchParams = []*commonpb.KeyValuePair{ + { + Key: AnnsFieldKey, + Value: floatVecField, + }, + { + Key: TopKKey, + Value: "10", + }, + } + + // no metric type + assert.Error(t, task.PreExecute(ctx)) + task.query.SearchParams = []*commonpb.KeyValuePair{ + { + Key: AnnsFieldKey, + Value: floatVecField, + }, + { + Key: TopKKey, + Value: "10", + }, + { + Key: MetricTypeKey, + Value: distance.L2, + }, + } + + // no search params + assert.Error(t, task.PreExecute(ctx)) + task.query.SearchParams = []*commonpb.KeyValuePair{ + { + Key: AnnsFieldKey, + Value: int64Field, + }, + { + Key: TopKKey, + Value: "10", + }, + { + Key: MetricTypeKey, + Value: distance.L2, + }, + { + Key: SearchParamsKey, + Value: `{"nprobe": 10}`, + }, + } + + // failed to create query plan + assert.Error(t, task.PreExecute(ctx)) + task.query.SearchParams = []*commonpb.KeyValuePair{ + { + Key: AnnsFieldKey, + Value: floatVecField, + }, + { + Key: TopKKey, + Value: "10", + }, + { + Key: MetricTypeKey, + Value: distance.L2, + }, + { + Key: SearchParamsKey, + Value: `{"nprobe": 10}`, + }, + } + + // field not exist + task.query.OutputFields = []string{int64Field + funcutil.GenRandomStr()} + assert.Error(t, task.PreExecute(ctx)) + // contain vector field + task.query.OutputFields = []string{floatVecField} + assert.Error(t, task.PreExecute(ctx)) + task.query.OutputFields = []string{int64Field} + + // partition + rc.showPartitionsFunc = func(ctx context.Context, request *milvuspb.ShowPartitionsRequest) (*milvuspb.ShowPartitionsResponse, error) { + return nil, errors.New("mock") + } + assert.Error(t, task.PreExecute(ctx)) + rc.showPartitionsFunc = nil + + // TODO(dragondriver): test partition-related error +}