diff --git a/Makefile b/Makefile index 8ca4f3d629..acf6b673be 100644 --- a/Makefile +++ b/Makefile @@ -408,6 +408,7 @@ generate-mockery-querycoord: getdeps generate-mockery-querynode: getdeps build-cpp @source $(PWD)/scripts/setenv.sh # setup PKG_CONFIG_PATH + $(INSTALL_PATH)/mockery --name=QueryHook --dir=$(PWD)/internal/querynodev2/optimizers --output=$(PWD)/internal/querynodev2/optimizers --filename=mock_query_hook.go --with-expecter --outpkg=optimizers --structname=MockQueryHook --inpackage $(INSTALL_PATH)/mockery --name=Manager --dir=$(PWD)/internal/querynodev2/cluster --output=$(PWD)/internal/querynodev2/cluster --filename=mock_manager.go --with-expecter --outpkg=cluster --structname=MockManager --inpackage $(INSTALL_PATH)/mockery --name=SegmentManager --dir=$(PWD)/internal/querynodev2/segments --output=$(PWD)/internal/querynodev2/segments --filename=mock_segment_manager.go --with-expecter --outpkg=segments --structname=MockSegmentManager --inpackage $(INSTALL_PATH)/mockery --name=CollectionManager --dir=$(PWD)/internal/querynodev2/segments --output=$(PWD)/internal/querynodev2/segments --filename=mock_collection_manager.go --with-expecter --outpkg=segments --structname=MockCollectionManager --inpackage diff --git a/internal/querynodev2/delegator/mock_delegator.go b/internal/querynodev2/delegator/mock_delegator.go index 279eb88ee7..c1f5e95e0c 100644 --- a/internal/querynodev2/delegator/mock_delegator.go +++ b/internal/querynodev2/delegator/mock_delegator.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.16.0. DO NOT EDIT. +// Code generated by mockery v2.32.4. DO NOT EDIT. package delegator @@ -53,6 +53,11 @@ func (_c *MockShardDelegator_Close_Call) Return() *MockShardDelegator_Close_Call return _c } +func (_c *MockShardDelegator_Close_Call) RunAndReturn(run func()) *MockShardDelegator_Close_Call { + _c.Call.Return(run) + return _c +} + // Collection provides a mock function with given fields: func (_m *MockShardDelegator) Collection() int64 { ret := _m.Called() @@ -89,11 +94,20 @@ func (_c *MockShardDelegator_Collection_Call) Return(_a0 int64) *MockShardDelega return _c } +func (_c *MockShardDelegator_Collection_Call) RunAndReturn(run func() int64) *MockShardDelegator_Collection_Call { + _c.Call.Return(run) + return _c +} + // GetSegmentInfo provides a mock function with given fields: readable func (_m *MockShardDelegator) GetSegmentInfo(readable bool) ([]SnapshotItem, []SegmentEntry) { ret := _m.Called(readable) var r0 []SnapshotItem + var r1 []SegmentEntry + if rf, ok := ret.Get(0).(func(bool) ([]SnapshotItem, []SegmentEntry)); ok { + return rf(readable) + } if rf, ok := ret.Get(0).(func(bool) []SnapshotItem); ok { r0 = rf(readable) } else { @@ -102,7 +116,6 @@ func (_m *MockShardDelegator) GetSegmentInfo(readable bool) ([]SnapshotItem, []S } } - var r1 []SegmentEntry if rf, ok := ret.Get(1).(func(bool) []SegmentEntry); ok { r1 = rf(readable) } else { @@ -137,11 +150,20 @@ func (_c *MockShardDelegator_GetSegmentInfo_Call) Return(sealed []SnapshotItem, return _c } +func (_c *MockShardDelegator_GetSegmentInfo_Call) RunAndReturn(run func(bool) ([]SnapshotItem, []SegmentEntry)) *MockShardDelegator_GetSegmentInfo_Call { + _c.Call.Return(run) + return _c +} + // GetStatistics provides a mock function with given fields: ctx, req func (_m *MockShardDelegator) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error) { ret := _m.Called(ctx, req) var r0 []*internalpb.GetStatisticsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error)); ok { + return rf(ctx, req) + } if rf, ok := ret.Get(0).(func(context.Context, *querypb.GetStatisticsRequest) []*internalpb.GetStatisticsResponse); ok { r0 = rf(ctx, req) } else { @@ -150,7 +172,6 @@ func (_m *MockShardDelegator) GetStatistics(ctx context.Context, req *querypb.Ge } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, *querypb.GetStatisticsRequest) error); ok { r1 = rf(ctx, req) } else { @@ -184,6 +205,11 @@ func (_c *MockShardDelegator_GetStatistics_Call) Return(_a0 []*internalpb.GetSta return _c } +func (_c *MockShardDelegator_GetStatistics_Call) RunAndReturn(run func(context.Context, *querypb.GetStatisticsRequest) ([]*internalpb.GetStatisticsResponse, error)) *MockShardDelegator_GetStatistics_Call { + _c.Call.Return(run) + return _c +} + // GetTargetVersion provides a mock function with given fields: func (_m *MockShardDelegator) GetTargetVersion() int64 { ret := _m.Called() @@ -220,6 +246,11 @@ func (_c *MockShardDelegator_GetTargetVersion_Call) Return(_a0 int64) *MockShard return _c } +func (_c *MockShardDelegator_GetTargetVersion_Call) RunAndReturn(run func() int64) *MockShardDelegator_GetTargetVersion_Call { + _c.Call.Return(run) + return _c +} + // LoadGrowing provides a mock function with given fields: ctx, infos, version func (_m *MockShardDelegator) LoadGrowing(ctx context.Context, infos []*querypb.SegmentLoadInfo, version int64) error { ret := _m.Called(ctx, infos, version) @@ -259,6 +290,11 @@ func (_c *MockShardDelegator_LoadGrowing_Call) Return(_a0 error) *MockShardDeleg return _c } +func (_c *MockShardDelegator_LoadGrowing_Call) RunAndReturn(run func(context.Context, []*querypb.SegmentLoadInfo, int64) error) *MockShardDelegator_LoadGrowing_Call { + _c.Call.Return(run) + return _c +} + // LoadSegments provides a mock function with given fields: ctx, req func (_m *MockShardDelegator) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error { ret := _m.Called(ctx, req) @@ -297,6 +333,11 @@ func (_c *MockShardDelegator_LoadSegments_Call) Return(_a0 error) *MockShardDele return _c } +func (_c *MockShardDelegator_LoadSegments_Call) RunAndReturn(run func(context.Context, *querypb.LoadSegmentsRequest) error) *MockShardDelegator_LoadSegments_Call { + _c.Call.Return(run) + return _c +} + // ProcessDelete provides a mock function with given fields: deleteData, ts func (_m *MockShardDelegator) ProcessDelete(deleteData []*DeleteData, ts uint64) { _m.Called(deleteData, ts) @@ -326,6 +367,11 @@ func (_c *MockShardDelegator_ProcessDelete_Call) Return() *MockShardDelegator_Pr return _c } +func (_c *MockShardDelegator_ProcessDelete_Call) RunAndReturn(run func([]*DeleteData, uint64)) *MockShardDelegator_ProcessDelete_Call { + _c.Call.Return(run) + return _c +} + // ProcessInsert provides a mock function with given fields: insertRecords func (_m *MockShardDelegator) ProcessInsert(insertRecords map[int64]*InsertData) { _m.Called(insertRecords) @@ -354,11 +400,20 @@ func (_c *MockShardDelegator_ProcessInsert_Call) Return() *MockShardDelegator_Pr return _c } +func (_c *MockShardDelegator_ProcessInsert_Call) RunAndReturn(run func(map[int64]*InsertData)) *MockShardDelegator_ProcessInsert_Call { + _c.Call.Return(run) + return _c +} + // Query provides a mock function with given fields: ctx, req func (_m *MockShardDelegator) Query(ctx context.Context, req *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error) { ret := _m.Called(ctx, req) var r0 []*internalpb.RetrieveResults + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error)); ok { + return rf(ctx, req) + } if rf, ok := ret.Get(0).(func(context.Context, *querypb.QueryRequest) []*internalpb.RetrieveResults); ok { r0 = rf(ctx, req) } else { @@ -367,7 +422,6 @@ func (_m *MockShardDelegator) Query(ctx context.Context, req *querypb.QueryReque } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, *querypb.QueryRequest) error); ok { r1 = rf(ctx, req) } else { @@ -401,6 +455,11 @@ func (_c *MockShardDelegator_Query_Call) Return(_a0 []*internalpb.RetrieveResult return _c } +func (_c *MockShardDelegator_Query_Call) RunAndReturn(run func(context.Context, *querypb.QueryRequest) ([]*internalpb.RetrieveResults, error)) *MockShardDelegator_Query_Call { + _c.Call.Return(run) + return _c +} + // QueryStream provides a mock function with given fields: ctx, req, srv func (_m *MockShardDelegator) QueryStream(ctx context.Context, req *querypb.QueryRequest, srv streamrpc.QueryStreamServer) error { ret := _m.Called(ctx, req, srv) @@ -440,6 +499,11 @@ func (_c *MockShardDelegator_QueryStream_Call) Return(_a0 error) *MockShardDeleg return _c } +func (_c *MockShardDelegator_QueryStream_Call) RunAndReturn(run func(context.Context, *querypb.QueryRequest, streamrpc.QueryStreamServer) error) *MockShardDelegator_QueryStream_Call { + _c.Call.Return(run) + return _c +} + // ReleaseSegments provides a mock function with given fields: ctx, req, force func (_m *MockShardDelegator) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest, force bool) error { ret := _m.Called(ctx, req, force) @@ -479,11 +543,20 @@ func (_c *MockShardDelegator_ReleaseSegments_Call) Return(_a0 error) *MockShardD return _c } +func (_c *MockShardDelegator_ReleaseSegments_Call) RunAndReturn(run func(context.Context, *querypb.ReleaseSegmentsRequest, bool) error) *MockShardDelegator_ReleaseSegments_Call { + _c.Call.Return(run) + return _c +} + // Search provides a mock function with given fields: ctx, req func (_m *MockShardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) { ret := _m.Called(ctx, req) var r0 []*internalpb.SearchResults + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) ([]*internalpb.SearchResults, error)); ok { + return rf(ctx, req) + } if rf, ok := ret.Get(0).(func(context.Context, *querypb.SearchRequest) []*internalpb.SearchResults); ok { r0 = rf(ctx, req) } else { @@ -492,7 +565,6 @@ func (_m *MockShardDelegator) Search(ctx context.Context, req *querypb.SearchReq } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, *querypb.SearchRequest) error); ok { r1 = rf(ctx, req) } else { @@ -526,6 +598,11 @@ func (_c *MockShardDelegator_Search_Call) Return(_a0 []*internalpb.SearchResults return _c } +func (_c *MockShardDelegator_Search_Call) RunAndReturn(run func(context.Context, *querypb.SearchRequest) ([]*internalpb.SearchResults, error)) *MockShardDelegator_Search_Call { + _c.Call.Return(run) + return _c +} + // Serviceable provides a mock function with given fields: func (_m *MockShardDelegator) Serviceable() bool { ret := _m.Called() @@ -562,6 +639,11 @@ func (_c *MockShardDelegator_Serviceable_Call) Return(_a0 bool) *MockShardDelega return _c } +func (_c *MockShardDelegator_Serviceable_Call) RunAndReturn(run func() bool) *MockShardDelegator_Serviceable_Call { + _c.Call.Return(run) + return _c +} + // Start provides a mock function with given fields: func (_m *MockShardDelegator) Start() { _m.Called() @@ -589,6 +671,11 @@ func (_c *MockShardDelegator_Start_Call) Return() *MockShardDelegator_Start_Call return _c } +func (_c *MockShardDelegator_Start_Call) RunAndReturn(run func()) *MockShardDelegator_Start_Call { + _c.Call.Return(run) + return _c +} + // SyncDistribution provides a mock function with given fields: ctx, entries func (_m *MockShardDelegator) SyncDistribution(ctx context.Context, entries ...SegmentEntry) { _va := make([]interface{}, len(entries)) @@ -719,8 +806,7 @@ func (_c *MockShardDelegator_Version_Call) RunAndReturn(run func() int64) *MockS func NewMockShardDelegator(t interface { mock.TestingT Cleanup(func()) -}, -) *MockShardDelegator { +}) *MockShardDelegator { mock := &MockShardDelegator{} mock.Mock.Test(t) diff --git a/internal/querynodev2/handlers_test.go b/internal/querynodev2/handlers_test.go index 774586a556..0088d6f531 100644 --- a/internal/querynodev2/handlers_test.go +++ b/internal/querynodev2/handlers_test.go @@ -31,6 +31,7 @@ import ( "github.com/milvus-io/milvus/internal/proto/planpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/delegator" + "github.com/milvus-io/milvus/internal/querynodev2/optimizers" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/pkg/common" @@ -180,7 +181,7 @@ func (suite *OptimizeSearchParamSuite) TestOptimizeSearchParam() { defer cancel() suite.Run("normal_run", func() { - mockHook := &MockQueryHook{} + mockHook := optimizers.NewMockQueryHook(suite.T()) mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { params[common.TopKKey] = int64(50) params[common.SearchParamKey] = `{"param": 2}` @@ -237,7 +238,7 @@ func (suite *OptimizeSearchParamSuite) TestOptimizeSearchParam() { }) suite.Run("other_plannode", func() { - mockHook := &MockQueryHook{} + mockHook := optimizers.NewMockQueryHook(suite.T()) mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { params[common.TopKKey] = int64(50) params[common.SearchParamKey] = `{"param": 2}` @@ -262,11 +263,7 @@ func (suite *OptimizeSearchParamSuite) TestOptimizeSearchParam() { }) suite.Run("no_serialized_plan", func() { - mockHook := &MockQueryHook{} - mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { - params[common.TopKKey] = int64(50) - params[common.SearchParamKey] = `{"param": 2}` - }).Return(nil) + mockHook := optimizers.NewMockQueryHook(suite.T()) suite.node.queryHook = mockHook defer func() { suite.node.queryHook = nil }() @@ -278,7 +275,7 @@ func (suite *OptimizeSearchParamSuite) TestOptimizeSearchParam() { }) suite.Run("hook_run_error", func() { - mockHook := &MockQueryHook{} + mockHook := optimizers.NewMockQueryHook(suite.T()) mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { params[common.TopKKey] = int64(50) params[common.SearchParamKey] = `{"param": 2}` diff --git a/internal/querynodev2/mock_query_hook.go b/internal/querynodev2/optimizers/mock_query_hook.go similarity index 94% rename from internal/querynodev2/mock_query_hook.go rename to internal/querynodev2/optimizers/mock_query_hook.go index 79c19a42c0..7c9f5dab88 100644 --- a/internal/querynodev2/mock_query_hook.go +++ b/internal/querynodev2/optimizers/mock_query_hook.go @@ -1,7 +1,10 @@ -package querynodev2 +// Code generated by mockery v2.32.4. DO NOT EDIT. -import "github.com/stretchr/testify/mock" +package optimizers +import mock "github.com/stretchr/testify/mock" + +// MockQueryHook is an autogenerated mock type for the QueryHook type type MockQueryHook struct { mock.Mock } @@ -182,13 +185,12 @@ func (_c *MockQueryHook_Run_Call) RunAndReturn(run func(map[string]interface{}) return _c } -type mockConstructorTestingTNewMockQueryHook interface { +// NewMockQueryHook creates a new instance of MockQueryHook. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockQueryHook(t interface { mock.TestingT Cleanup(func()) -} - -// NewMockQueryHook creates a new instance of MockQueryHook. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -func NewMockQueryHook(t mockConstructorTestingTNewMockQueryHook) *MockQueryHook { +}) *MockQueryHook { mock := &MockQueryHook{} mock.Mock.Test(t) diff --git a/internal/querynodev2/optimizers/query_hook.go b/internal/querynodev2/optimizers/query_hook.go new file mode 100644 index 0000000000..c3703feba1 --- /dev/null +++ b/internal/querynodev2/optimizers/query_hook.go @@ -0,0 +1,9 @@ +package optimizers + +// QueryHook is the interface for search/query parameter optimizer. +type QueryHook interface { + Run(map[string]any) error + Init(string) error + InitTuningConfig(map[string]string) error + DeleteTuningConfig(string) error +} diff --git a/internal/querynodev2/segments/mock_segment.go b/internal/querynodev2/segments/mock_segment.go index b2bb133c56..e5377c7147 100644 --- a/internal/querynodev2/segments/mock_segment.go +++ b/internal/querynodev2/segments/mock_segment.go @@ -1024,49 +1024,6 @@ func (_c *MockSegment_UpdateBloomFilter_Call) RunAndReturn(run func([]storage.Pr return _c } -// ValidateIndexedFieldsData provides a mock function with given fields: ctx, result -func (_m *MockSegment) ValidateIndexedFieldsData(ctx context.Context, result *segcorepb.RetrieveResults) error { - ret := _m.Called(ctx, result) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *segcorepb.RetrieveResults) error); ok { - r0 = rf(ctx, result) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// MockSegment_ValidateIndexedFieldsData_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ValidateIndexedFieldsData' -type MockSegment_ValidateIndexedFieldsData_Call struct { - *mock.Call -} - -// ValidateIndexedFieldsData is a helper method to define mock.On call -// - ctx context.Context -// - result *segcorepb.RetrieveResults -func (_e *MockSegment_Expecter) ValidateIndexedFieldsData(ctx interface{}, result interface{}) *MockSegment_ValidateIndexedFieldsData_Call { - return &MockSegment_ValidateIndexedFieldsData_Call{Call: _e.mock.On("ValidateIndexedFieldsData", ctx, result)} -} - -func (_c *MockSegment_ValidateIndexedFieldsData_Call) Run(run func(ctx context.Context, result *segcorepb.RetrieveResults)) *MockSegment_ValidateIndexedFieldsData_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*segcorepb.RetrieveResults)) - }) - return _c -} - -func (_c *MockSegment_ValidateIndexedFieldsData_Call) Return(_a0 error) *MockSegment_ValidateIndexedFieldsData_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockSegment_ValidateIndexedFieldsData_Call) RunAndReturn(run func(context.Context, *segcorepb.RetrieveResults) error) *MockSegment_ValidateIndexedFieldsData_Call { - _c.Call.Return(run) - return _c -} - // Version provides a mock function with given fields: func (_m *MockSegment) Version() int64 { ret := _m.Called() diff --git a/internal/querynodev2/server.go b/internal/querynodev2/server.go index 2f13936661..7779e00093 100644 --- a/internal/querynodev2/server.go +++ b/internal/querynodev2/server.go @@ -49,6 +49,7 @@ import ( grpcquerynodeclient "github.com/milvus-io/milvus/internal/distributed/querynode/client" "github.com/milvus-io/milvus/internal/querynodev2/cluster" "github.com/milvus-io/milvus/internal/querynodev2/delegator" + "github.com/milvus-io/milvus/internal/querynodev2/optimizers" "github.com/milvus-io/milvus/internal/querynodev2/pipeline" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/querynodev2/tasks" @@ -127,7 +128,7 @@ type QueryNode struct { knnPool *conc.Pool*/ // parameter turning hook - queryHook queryHook + queryHook optimizers.QueryHook } // NewQueryNode will return a QueryNode with abnormal state. @@ -488,13 +489,6 @@ func (node *QueryNode) SetAddress(address string) { node.address = address } -type queryHook interface { - Run(map[string]any) error - Init(string) error - InitTuningConfig(map[string]string) error - DeleteTuningConfig(string) error -} - // initHook initializes parameter tuning hook. func (node *QueryNode) initHook() error { path := paramtable.Get().QueryNodeCfg.SoPath.GetValue() @@ -514,7 +508,7 @@ func (node *QueryNode) initHook() error { return fmt.Errorf("fail to find the 'QueryNodePlugin' object in the plugin, error: %s", err.Error()) } - hoo, ok := h.(queryHook) + hoo, ok := h.(optimizers.QueryHook) if !ok { return fmt.Errorf("fail to convert the `Hook` interface") } diff --git a/internal/querynodev2/server_test.go b/internal/querynodev2/server_test.go index a682a7e030..bc7c0bfb34 100644 --- a/internal/querynodev2/server_test.go +++ b/internal/querynodev2/server_test.go @@ -32,6 +32,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/querypb" + "github.com/milvus-io/milvus/internal/querynodev2/optimizers" "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/storage" "github.com/milvus-io/milvus/internal/util/dependency" @@ -157,7 +158,7 @@ func (suite *QueryNodeSuite) TestInit_QueryHook() { err = suite.node.Init() suite.NoError(err) - mockHook := &MockQueryHook{} + mockHook := optimizers.NewMockQueryHook(suite.T()) suite.node.queryHook = mockHook suite.node.handleQueryHookEvent()