diff --git a/internal/datacoord/.mockery.yaml b/internal/datacoord/.mockery.yaml index d618c870fe..958024928d 100644 --- a/internal/datacoord/.mockery.yaml +++ b/internal/datacoord/.mockery.yaml @@ -32,6 +32,7 @@ packages: ChannelManager: SubCluster: StatsJobManager: + ImportMeta: github.com/milvus-io/milvus/internal/datacoord/allocator: interfaces: Allocator: diff --git a/internal/datacoord/import_util.go b/internal/datacoord/import_util.go index a4fd30b6ec..6a9d33768c 100644 --- a/internal/datacoord/import_util.go +++ b/internal/datacoord/import_util.go @@ -22,6 +22,7 @@ import ( "math" "path" "sort" + "sync" "time" "github.com/cockroachdb/errors" @@ -39,6 +40,8 @@ import ( "github.com/milvus-io/milvus/pkg/v2/proto/indexpb" "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/taskcommon" + "github.com/milvus-io/milvus/pkg/v2/util/conc" + "github.com/milvus-io/milvus/pkg/v2/util/hardware" "github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" "github.com/milvus-io/milvus/pkg/v2/util/timerecord" @@ -549,7 +552,9 @@ func getIndexBuildingProgress(ctx context.Context, jobID int64, importMeta Impor // 10%: Completed // TODO: Wrap a function to map status to user status. // TODO: Save these progress to job instead of recalculating. -func GetJobProgress(ctx context.Context, jobID int64, importMeta ImportMeta, meta *meta, sjm StatsInspector) (int64, internalpb.ImportJobState, int64, int64, string) { +func GetJobProgress(ctx context.Context, jobID int64, + importMeta ImportMeta, meta *meta, sjm StatsInspector, +) (int64, internalpb.ImportJobState, int64, int64, string) { job := importMeta.GetJob(ctx, jobID) if job == nil { return 0, internalpb.ImportJobState_Failed, 0, 0, fmt.Sprintf("import job does not exist, jobID=%d", jobID) @@ -627,7 +632,9 @@ func DropImportTask(task ImportTask, cluster session.Cluster, tm ImportMeta) err return tm.UpdateTask(context.TODO(), task.GetTaskID(), UpdateNodeID(NullNodeID)) } -func ListBinlogsAndGroupBySegment(ctx context.Context, cm storage.ChunkManager, importFile *internalpb.ImportFile) ([]*internalpb.ImportFile, error) { +func ListBinlogsAndGroupBySegment(ctx context.Context, + cm storage.ChunkManager, importFile *internalpb.ImportFile, +) ([]*internalpb.ImportFile, error) { if len(importFile.GetPaths()) == 0 { return nil, merr.WrapErrImportFailed("no insert binlogs to import") } @@ -668,3 +675,73 @@ func ListBinlogsAndGroupBySegment(ctx context.Context, cm storage.ChunkManager, } return segmentImportFiles, nil } + +// ValidateBinlogImportRequest validates the binlog import request. +func ValidateBinlogImportRequest(ctx context.Context, cm storage.ChunkManager, + reqFiles []*msgpb.ImportFile, options []*commonpb.KeyValuePair, +) error { + files := lo.Map(reqFiles, func(file *msgpb.ImportFile, _ int) *internalpb.ImportFile { + return &internalpb.ImportFile{Id: file.GetId(), Paths: file.GetPaths()} + }) + _, err := ListBinlogImportRequestFiles(ctx, cm, files, options) + return err +} + +// ListBinlogImportRequestFiles lists the binlog files from the request. +// TODO: dyh, remove listing binlog after backup-restore derectly passed the segments paths. +func ListBinlogImportRequestFiles(ctx context.Context, cm storage.ChunkManager, + reqFiles []*internalpb.ImportFile, options []*commonpb.KeyValuePair, +) ([]*internalpb.ImportFile, error) { + isBackup := importutilv2.IsBackup(options) + if !isBackup { + return reqFiles, nil + } + resFiles := make([]*internalpb.ImportFile, 0) + pool := conc.NewPool[struct{}](hardware.GetCPUNum() * 2) + defer pool.Release() + futures := make([]*conc.Future[struct{}], 0, len(reqFiles)) + mu := &sync.Mutex{} + for _, importFile := range reqFiles { + importFile := importFile + futures = append(futures, pool.Submit(func() (struct{}, error) { + segmentPrefixes, err := ListBinlogsAndGroupBySegment(ctx, cm, importFile) + if err != nil { + return struct{}{}, err + } + mu.Lock() + defer mu.Unlock() + resFiles = append(resFiles, segmentPrefixes...) + return struct{}{}, nil + })) + } + err := conc.AwaitAll(futures...) + if err != nil { + return nil, merr.WrapErrImportFailed(fmt.Sprintf("list binlogs failed, err=%s", err)) + } + + resFiles = lo.Filter(resFiles, func(file *internalpb.ImportFile, _ int) bool { + return len(file.GetPaths()) > 0 + }) + if len(resFiles) == 0 { + return nil, merr.WrapErrImportFailed(fmt.Sprintf("no binlog to import, input=%s", reqFiles)) + } + if len(resFiles) > paramtable.Get().DataCoordCfg.MaxFilesPerImportReq.GetAsInt() { + return nil, merr.WrapErrImportFailed(fmt.Sprintf("The max number of import files should not exceed %d, but got %d", + paramtable.Get().DataCoordCfg.MaxFilesPerImportReq.GetAsInt(), len(resFiles))) + } + log.Info("list binlogs prefixes for import done", zap.Int("num", len(resFiles)), zap.Any("binlog_prefixes", resFiles)) + return resFiles, nil +} + +// ValidateMaxImportJobExceed checks if the number of import jobs exceeds the limit. +func ValidateMaxImportJobExceed(ctx context.Context, importMeta ImportMeta) error { + maxNum := paramtable.Get().DataCoordCfg.MaxImportJobNum.GetAsInt() + executingNum := importMeta.CountJobBy(ctx, WithoutJobStates(internalpb.ImportJobState_Completed, internalpb.ImportJobState_Failed)) + if executingNum >= maxNum { + return merr.WrapErrImportFailed( + fmt.Sprintf("The number of jobs has reached the limit, please try again later. " + + "If your request is set to only import a single file, " + + "please consider importing multiple files in one request for better efficiency.")) + } + return nil +} diff --git a/internal/datacoord/import_util_test.go b/internal/datacoord/import_util_test.go index ae94ab468b..9326d7be33 100644 --- a/internal/datacoord/import_util_test.go +++ b/internal/datacoord/import_util_test.go @@ -21,6 +21,7 @@ import ( "fmt" "math/rand" "path" + "strings" "testing" "time" @@ -31,6 +32,7 @@ import ( "go.uber.org/atomic" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/datacoord/allocator" "github.com/milvus-io/milvus/internal/datacoord/broker" @@ -885,3 +887,233 @@ func TestImportTask_MarshalJSON(t *testing.T) { assert.Equal(t, task.GetCreatedTime(), importTask.CreatedTime) assert.Equal(t, task.GetCompleteTime(), importTask.CompleteTime) } + +// TestImportUtil_ValidateBinlogImportRequest tests the validation of binlog import request +func TestImportUtil_ValidateBinlogImportRequest(t *testing.T) { + ctx := context.Background() + mockCM := mocks2.NewChunkManager(t) + + t.Run("empty files", func(t *testing.T) { + options := []*commonpb.KeyValuePair{ + { + Key: importutilv2.BackupFlag, + Value: "true", + }, + } + err := ValidateBinlogImportRequest(ctx, mockCM, nil, options) + assert.Error(t, err) + }) + + t.Run("valid files - not backup", func(t *testing.T) { + files := []*msgpb.ImportFile{ + { + Id: 1, + Paths: []string{"path1"}, + }, + } + err := ValidateBinlogImportRequest(ctx, mockCM, files, nil) + assert.NoError(t, err) + }) + + t.Run("invalid files - too many paths", func(t *testing.T) { + files := []*msgpb.ImportFile{ + { + Id: 1, + Paths: []string{"path1", "path2", "path3"}, + }, + } + options := []*commonpb.KeyValuePair{ + { + Key: importutilv2.BackupFlag, + Value: "true", + }, + } + err := ValidateBinlogImportRequest(ctx, mockCM, files, options) + assert.Error(t, err) + assert.Contains(t, err.Error(), "too many input paths") + }) +} + +// TestImportUtil_ListBinlogImportRequestFiles tests listing binlog files from import request +func TestImportUtil_ListBinlogImportRequestFiles(t *testing.T) { + ctx := context.Background() + + t.Run("empty files", func(t *testing.T) { + options := []*commonpb.KeyValuePair{ + { + Key: importutilv2.BackupFlag, + Value: "true", + }, + } + files, err := ListBinlogImportRequestFiles(ctx, nil, nil, options) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no binlog to import") + assert.Nil(t, files) + }) + + t.Run("not backup files", func(t *testing.T) { + reqFiles := []*internalpb.ImportFile{ + { + Paths: []string{"path1"}, + }, + } + files, err := ListBinlogImportRequestFiles(ctx, nil, reqFiles, nil) + assert.NoError(t, err) + assert.Equal(t, reqFiles, files) + }) + + t.Run("backup files - list error", func(t *testing.T) { + reqFiles := []*internalpb.ImportFile{ + { + Paths: []string{"path1"}, + }, + } + options := []*commonpb.KeyValuePair{ + { + Key: importutilv2.BackupFlag, + Value: "true", + }, + } + mockCM := mocks2.NewChunkManager(t) + mockCM.EXPECT().WalkWithPrefix(mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(errors.New("mock error")) + files, err := ListBinlogImportRequestFiles(ctx, mockCM, reqFiles, options) + assert.Error(t, err) + assert.Contains(t, err.Error(), "list binlogs failed") + assert.Nil(t, files) + }) + + t.Run("backup files - success", func(t *testing.T) { + reqFiles := []*internalpb.ImportFile{ + { + Paths: []string{"path1"}, + }, + } + options := []*commonpb.KeyValuePair{ + { + Key: importutilv2.BackupFlag, + Value: "true", + }, + } + mockCM := mocks2.NewChunkManager(t) + mockCM.EXPECT().WalkWithPrefix(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, prefix string, recursive bool, walkFunc storage.ChunkObjectWalkFunc) error { + walkFunc(&storage.ChunkObjectInfo{ + FilePath: "path1", + }) + return nil + }) + files, err := ListBinlogImportRequestFiles(ctx, mockCM, reqFiles, options) + assert.NoError(t, err) + assert.Equal(t, 1, len(files)) + assert.Equal(t, "path1", files[0].GetPaths()[0]) + }) + + t.Run("backup files - empty result", func(t *testing.T) { + reqFiles := []*internalpb.ImportFile{ + { + Paths: []string{"path1"}, + }, + } + options := []*commonpb.KeyValuePair{ + { + Key: importutilv2.BackupFlag, + Value: "true", + }, + } + mockCM := mocks2.NewChunkManager(t) + mockCM.EXPECT().WalkWithPrefix(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, prefix string, recursive bool, walkFunc storage.ChunkObjectWalkFunc) error { + return nil + }) + files, err := ListBinlogImportRequestFiles(ctx, mockCM, reqFiles, options) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no binlog to import") + assert.Nil(t, files) + }) + + t.Run("backup files - too many files", func(t *testing.T) { + maxFiles := paramtable.Get().DataCoordCfg.MaxFilesPerImportReq.GetAsInt() + reqFiles := make([]*internalpb.ImportFile, maxFiles+1) + for i := 0; i < maxFiles+1; i++ { + reqFiles[i] = &internalpb.ImportFile{ + Paths: []string{fmt.Sprintf("path%d", i)}, + } + } + options := []*commonpb.KeyValuePair{ + { + Key: importutilv2.BackupFlag, + Value: "true", + }, + } + mockCM := mocks2.NewChunkManager(t) + mockCM.EXPECT().WalkWithPrefix(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, prefix string, recursive bool, walkFunc storage.ChunkObjectWalkFunc) error { + for i := 0; i < maxFiles+1; i++ { + walkFunc(&storage.ChunkObjectInfo{ + FilePath: fmt.Sprintf("path%d", i), + }) + } + return nil + }) + files, err := ListBinlogImportRequestFiles(ctx, mockCM, reqFiles, options) + assert.Error(t, err) + assert.Contains(t, err.Error(), fmt.Sprintf("The max number of import files should not exceed %d", maxFiles)) + assert.Nil(t, files) + }) + + t.Run("backup files - multiple files with delta", func(t *testing.T) { + reqFiles := []*internalpb.ImportFile{ + { + Paths: []string{"insert/path1", "delta/path1"}, + }, + } + options := []*commonpb.KeyValuePair{ + { + Key: importutilv2.BackupFlag, + Value: "true", + }, + } + mockCM := mocks2.NewChunkManager(t) + mockCM.EXPECT().WalkWithPrefix(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn( + func(ctx context.Context, prefix string, recursive bool, walkFunc storage.ChunkObjectWalkFunc) error { + if strings.Contains(prefix, "insert") { + walkFunc(&storage.ChunkObjectInfo{ + FilePath: "insert/path1", + }) + } else if strings.Contains(prefix, "delta") { + walkFunc(&storage.ChunkObjectInfo{ + FilePath: "delta/path1", + }) + } + return nil + }).Times(2) + files, err := ListBinlogImportRequestFiles(ctx, mockCM, reqFiles, options) + assert.NoError(t, err) + assert.Equal(t, 1, len(files)) + assert.Equal(t, 2, len(files[0].GetPaths())) + assert.Equal(t, "insert/path1", files[0].GetPaths()[0]) + assert.Equal(t, "delta/path1", files[0].GetPaths()[1]) + }) +} + +// TestImportUtil_ValidateMaxImportJobExceed tests validation of maximum import jobs +func TestImportUtil_ValidateMaxImportJobExceed(t *testing.T) { + ctx := context.Background() + + t.Run("job count within limit", func(t *testing.T) { + mockImportMeta := NewMockImportMeta(t) + mockImportMeta.EXPECT().CountJobBy(mock.Anything, mock.Anything).Return(1) + err := ValidateMaxImportJobExceed(ctx, mockImportMeta) + assert.NoError(t, err) + }) + + t.Run("job count exceeds limit", func(t *testing.T) { + mockImportMeta := NewMockImportMeta(t) + mockImportMeta.EXPECT().CountJobBy(mock.Anything, mock.Anything). + Return(paramtable.Get().DataCoordCfg.MaxImportJobNum.GetAsInt() + 1) + err := ValidateMaxImportJobExceed(ctx, mockImportMeta) + assert.Error(t, err) + assert.Contains(t, err.Error(), "The number of jobs has reached the limit") + }) +} diff --git a/internal/datacoord/mock_import_meta.go b/internal/datacoord/mock_import_meta.go new file mode 100644 index 0000000000..3970547795 --- /dev/null +++ b/internal/datacoord/mock_import_meta.go @@ -0,0 +1,679 @@ +// Code generated by mockery v2.53.3. DO NOT EDIT. + +package datacoord + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" +) + +// MockImportMeta is an autogenerated mock type for the ImportMeta type +type MockImportMeta struct { + mock.Mock +} + +type MockImportMeta_Expecter struct { + mock *mock.Mock +} + +func (_m *MockImportMeta) EXPECT() *MockImportMeta_Expecter { + return &MockImportMeta_Expecter{mock: &_m.Mock} +} + +// AddJob provides a mock function with given fields: ctx, job +func (_m *MockImportMeta) AddJob(ctx context.Context, job ImportJob) error { + ret := _m.Called(ctx, job) + + if len(ret) == 0 { + panic("no return value specified for AddJob") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, ImportJob) error); ok { + r0 = rf(ctx, job) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockImportMeta_AddJob_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddJob' +type MockImportMeta_AddJob_Call struct { + *mock.Call +} + +// AddJob is a helper method to define mock.On call +// - ctx context.Context +// - job ImportJob +func (_e *MockImportMeta_Expecter) AddJob(ctx interface{}, job interface{}) *MockImportMeta_AddJob_Call { + return &MockImportMeta_AddJob_Call{Call: _e.mock.On("AddJob", ctx, job)} +} + +func (_c *MockImportMeta_AddJob_Call) Run(run func(ctx context.Context, job ImportJob)) *MockImportMeta_AddJob_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(ImportJob)) + }) + return _c +} + +func (_c *MockImportMeta_AddJob_Call) Return(_a0 error) *MockImportMeta_AddJob_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImportMeta_AddJob_Call) RunAndReturn(run func(context.Context, ImportJob) error) *MockImportMeta_AddJob_Call { + _c.Call.Return(run) + return _c +} + +// AddTask provides a mock function with given fields: ctx, task +func (_m *MockImportMeta) AddTask(ctx context.Context, task ImportTask) error { + ret := _m.Called(ctx, task) + + if len(ret) == 0 { + panic("no return value specified for AddTask") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, ImportTask) error); ok { + r0 = rf(ctx, task) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockImportMeta_AddTask_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddTask' +type MockImportMeta_AddTask_Call struct { + *mock.Call +} + +// AddTask is a helper method to define mock.On call +// - ctx context.Context +// - task ImportTask +func (_e *MockImportMeta_Expecter) AddTask(ctx interface{}, task interface{}) *MockImportMeta_AddTask_Call { + return &MockImportMeta_AddTask_Call{Call: _e.mock.On("AddTask", ctx, task)} +} + +func (_c *MockImportMeta_AddTask_Call) Run(run func(ctx context.Context, task ImportTask)) *MockImportMeta_AddTask_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(ImportTask)) + }) + return _c +} + +func (_c *MockImportMeta_AddTask_Call) Return(_a0 error) *MockImportMeta_AddTask_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImportMeta_AddTask_Call) RunAndReturn(run func(context.Context, ImportTask) error) *MockImportMeta_AddTask_Call { + _c.Call.Return(run) + return _c +} + +// CountJobBy provides a mock function with given fields: ctx, filters +func (_m *MockImportMeta) CountJobBy(ctx context.Context, filters ...ImportJobFilter) int { + _va := make([]interface{}, len(filters)) + for _i := range filters { + _va[_i] = filters[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for CountJobBy") + } + + var r0 int + if rf, ok := ret.Get(0).(func(context.Context, ...ImportJobFilter) int); ok { + r0 = rf(ctx, filters...) + } else { + r0 = ret.Get(0).(int) + } + + return r0 +} + +// MockImportMeta_CountJobBy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CountJobBy' +type MockImportMeta_CountJobBy_Call struct { + *mock.Call +} + +// CountJobBy is a helper method to define mock.On call +// - ctx context.Context +// - filters ...ImportJobFilter +func (_e *MockImportMeta_Expecter) CountJobBy(ctx interface{}, filters ...interface{}) *MockImportMeta_CountJobBy_Call { + return &MockImportMeta_CountJobBy_Call{Call: _e.mock.On("CountJobBy", + append([]interface{}{ctx}, filters...)...)} +} + +func (_c *MockImportMeta_CountJobBy_Call) Run(run func(ctx context.Context, filters ...ImportJobFilter)) *MockImportMeta_CountJobBy_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]ImportJobFilter, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(ImportJobFilter) + } + } + run(args[0].(context.Context), variadicArgs...) + }) + return _c +} + +func (_c *MockImportMeta_CountJobBy_Call) Return(_a0 int) *MockImportMeta_CountJobBy_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImportMeta_CountJobBy_Call) RunAndReturn(run func(context.Context, ...ImportJobFilter) int) *MockImportMeta_CountJobBy_Call { + _c.Call.Return(run) + return _c +} + +// GetJob provides a mock function with given fields: ctx, jobID +func (_m *MockImportMeta) GetJob(ctx context.Context, jobID int64) ImportJob { + ret := _m.Called(ctx, jobID) + + if len(ret) == 0 { + panic("no return value specified for GetJob") + } + + var r0 ImportJob + if rf, ok := ret.Get(0).(func(context.Context, int64) ImportJob); ok { + r0 = rf(ctx, jobID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(ImportJob) + } + } + + return r0 +} + +// MockImportMeta_GetJob_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetJob' +type MockImportMeta_GetJob_Call struct { + *mock.Call +} + +// GetJob is a helper method to define mock.On call +// - ctx context.Context +// - jobID int64 +func (_e *MockImportMeta_Expecter) GetJob(ctx interface{}, jobID interface{}) *MockImportMeta_GetJob_Call { + return &MockImportMeta_GetJob_Call{Call: _e.mock.On("GetJob", ctx, jobID)} +} + +func (_c *MockImportMeta_GetJob_Call) Run(run func(ctx context.Context, jobID int64)) *MockImportMeta_GetJob_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockImportMeta_GetJob_Call) Return(_a0 ImportJob) *MockImportMeta_GetJob_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImportMeta_GetJob_Call) RunAndReturn(run func(context.Context, int64) ImportJob) *MockImportMeta_GetJob_Call { + _c.Call.Return(run) + return _c +} + +// GetJobBy provides a mock function with given fields: ctx, filters +func (_m *MockImportMeta) GetJobBy(ctx context.Context, filters ...ImportJobFilter) []ImportJob { + _va := make([]interface{}, len(filters)) + for _i := range filters { + _va[_i] = filters[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for GetJobBy") + } + + var r0 []ImportJob + if rf, ok := ret.Get(0).(func(context.Context, ...ImportJobFilter) []ImportJob); ok { + r0 = rf(ctx, filters...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]ImportJob) + } + } + + return r0 +} + +// MockImportMeta_GetJobBy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetJobBy' +type MockImportMeta_GetJobBy_Call struct { + *mock.Call +} + +// GetJobBy is a helper method to define mock.On call +// - ctx context.Context +// - filters ...ImportJobFilter +func (_e *MockImportMeta_Expecter) GetJobBy(ctx interface{}, filters ...interface{}) *MockImportMeta_GetJobBy_Call { + return &MockImportMeta_GetJobBy_Call{Call: _e.mock.On("GetJobBy", + append([]interface{}{ctx}, filters...)...)} +} + +func (_c *MockImportMeta_GetJobBy_Call) Run(run func(ctx context.Context, filters ...ImportJobFilter)) *MockImportMeta_GetJobBy_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]ImportJobFilter, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(ImportJobFilter) + } + } + run(args[0].(context.Context), variadicArgs...) + }) + return _c +} + +func (_c *MockImportMeta_GetJobBy_Call) Return(_a0 []ImportJob) *MockImportMeta_GetJobBy_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImportMeta_GetJobBy_Call) RunAndReturn(run func(context.Context, ...ImportJobFilter) []ImportJob) *MockImportMeta_GetJobBy_Call { + _c.Call.Return(run) + return _c +} + +// GetTask provides a mock function with given fields: ctx, taskID +func (_m *MockImportMeta) GetTask(ctx context.Context, taskID int64) ImportTask { + ret := _m.Called(ctx, taskID) + + if len(ret) == 0 { + panic("no return value specified for GetTask") + } + + var r0 ImportTask + if rf, ok := ret.Get(0).(func(context.Context, int64) ImportTask); ok { + r0 = rf(ctx, taskID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(ImportTask) + } + } + + return r0 +} + +// MockImportMeta_GetTask_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetTask' +type MockImportMeta_GetTask_Call struct { + *mock.Call +} + +// GetTask is a helper method to define mock.On call +// - ctx context.Context +// - taskID int64 +func (_e *MockImportMeta_Expecter) GetTask(ctx interface{}, taskID interface{}) *MockImportMeta_GetTask_Call { + return &MockImportMeta_GetTask_Call{Call: _e.mock.On("GetTask", ctx, taskID)} +} + +func (_c *MockImportMeta_GetTask_Call) Run(run func(ctx context.Context, taskID int64)) *MockImportMeta_GetTask_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockImportMeta_GetTask_Call) Return(_a0 ImportTask) *MockImportMeta_GetTask_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImportMeta_GetTask_Call) RunAndReturn(run func(context.Context, int64) ImportTask) *MockImportMeta_GetTask_Call { + _c.Call.Return(run) + return _c +} + +// GetTaskBy provides a mock function with given fields: ctx, filters +func (_m *MockImportMeta) GetTaskBy(ctx context.Context, filters ...ImportTaskFilter) []ImportTask { + _va := make([]interface{}, len(filters)) + for _i := range filters { + _va[_i] = filters[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for GetTaskBy") + } + + var r0 []ImportTask + if rf, ok := ret.Get(0).(func(context.Context, ...ImportTaskFilter) []ImportTask); ok { + r0 = rf(ctx, filters...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]ImportTask) + } + } + + return r0 +} + +// MockImportMeta_GetTaskBy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetTaskBy' +type MockImportMeta_GetTaskBy_Call struct { + *mock.Call +} + +// GetTaskBy is a helper method to define mock.On call +// - ctx context.Context +// - filters ...ImportTaskFilter +func (_e *MockImportMeta_Expecter) GetTaskBy(ctx interface{}, filters ...interface{}) *MockImportMeta_GetTaskBy_Call { + return &MockImportMeta_GetTaskBy_Call{Call: _e.mock.On("GetTaskBy", + append([]interface{}{ctx}, filters...)...)} +} + +func (_c *MockImportMeta_GetTaskBy_Call) Run(run func(ctx context.Context, filters ...ImportTaskFilter)) *MockImportMeta_GetTaskBy_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]ImportTaskFilter, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(ImportTaskFilter) + } + } + run(args[0].(context.Context), variadicArgs...) + }) + return _c +} + +func (_c *MockImportMeta_GetTaskBy_Call) Return(_a0 []ImportTask) *MockImportMeta_GetTaskBy_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImportMeta_GetTaskBy_Call) RunAndReturn(run func(context.Context, ...ImportTaskFilter) []ImportTask) *MockImportMeta_GetTaskBy_Call { + _c.Call.Return(run) + return _c +} + +// RemoveJob provides a mock function with given fields: ctx, jobID +func (_m *MockImportMeta) RemoveJob(ctx context.Context, jobID int64) error { + ret := _m.Called(ctx, jobID) + + if len(ret) == 0 { + panic("no return value specified for RemoveJob") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = rf(ctx, jobID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockImportMeta_RemoveJob_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveJob' +type MockImportMeta_RemoveJob_Call struct { + *mock.Call +} + +// RemoveJob is a helper method to define mock.On call +// - ctx context.Context +// - jobID int64 +func (_e *MockImportMeta_Expecter) RemoveJob(ctx interface{}, jobID interface{}) *MockImportMeta_RemoveJob_Call { + return &MockImportMeta_RemoveJob_Call{Call: _e.mock.On("RemoveJob", ctx, jobID)} +} + +func (_c *MockImportMeta_RemoveJob_Call) Run(run func(ctx context.Context, jobID int64)) *MockImportMeta_RemoveJob_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockImportMeta_RemoveJob_Call) Return(_a0 error) *MockImportMeta_RemoveJob_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImportMeta_RemoveJob_Call) RunAndReturn(run func(context.Context, int64) error) *MockImportMeta_RemoveJob_Call { + _c.Call.Return(run) + return _c +} + +// RemoveTask provides a mock function with given fields: ctx, taskID +func (_m *MockImportMeta) RemoveTask(ctx context.Context, taskID int64) error { + ret := _m.Called(ctx, taskID) + + if len(ret) == 0 { + panic("no return value specified for RemoveTask") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = rf(ctx, taskID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockImportMeta_RemoveTask_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveTask' +type MockImportMeta_RemoveTask_Call struct { + *mock.Call +} + +// RemoveTask is a helper method to define mock.On call +// - ctx context.Context +// - taskID int64 +func (_e *MockImportMeta_Expecter) RemoveTask(ctx interface{}, taskID interface{}) *MockImportMeta_RemoveTask_Call { + return &MockImportMeta_RemoveTask_Call{Call: _e.mock.On("RemoveTask", ctx, taskID)} +} + +func (_c *MockImportMeta_RemoveTask_Call) Run(run func(ctx context.Context, taskID int64)) *MockImportMeta_RemoveTask_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(int64)) + }) + return _c +} + +func (_c *MockImportMeta_RemoveTask_Call) Return(_a0 error) *MockImportMeta_RemoveTask_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImportMeta_RemoveTask_Call) RunAndReturn(run func(context.Context, int64) error) *MockImportMeta_RemoveTask_Call { + _c.Call.Return(run) + return _c +} + +// TaskStatsJSON provides a mock function with given fields: ctx +func (_m *MockImportMeta) TaskStatsJSON(ctx context.Context) string { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for TaskStatsJSON") + } + + var r0 string + if rf, ok := ret.Get(0).(func(context.Context) string); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(string) + } + + return r0 +} + +// MockImportMeta_TaskStatsJSON_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TaskStatsJSON' +type MockImportMeta_TaskStatsJSON_Call struct { + *mock.Call +} + +// TaskStatsJSON is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockImportMeta_Expecter) TaskStatsJSON(ctx interface{}) *MockImportMeta_TaskStatsJSON_Call { + return &MockImportMeta_TaskStatsJSON_Call{Call: _e.mock.On("TaskStatsJSON", ctx)} +} + +func (_c *MockImportMeta_TaskStatsJSON_Call) Run(run func(ctx context.Context)) *MockImportMeta_TaskStatsJSON_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockImportMeta_TaskStatsJSON_Call) Return(_a0 string) *MockImportMeta_TaskStatsJSON_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImportMeta_TaskStatsJSON_Call) RunAndReturn(run func(context.Context) string) *MockImportMeta_TaskStatsJSON_Call { + _c.Call.Return(run) + return _c +} + +// UpdateJob provides a mock function with given fields: ctx, jobID, actions +func (_m *MockImportMeta) UpdateJob(ctx context.Context, jobID int64, actions ...UpdateJobAction) error { + _va := make([]interface{}, len(actions)) + for _i := range actions { + _va[_i] = actions[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, jobID) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for UpdateJob") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int64, ...UpdateJobAction) error); ok { + r0 = rf(ctx, jobID, actions...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockImportMeta_UpdateJob_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateJob' +type MockImportMeta_UpdateJob_Call struct { + *mock.Call +} + +// UpdateJob is a helper method to define mock.On call +// - ctx context.Context +// - jobID int64 +// - actions ...UpdateJobAction +func (_e *MockImportMeta_Expecter) UpdateJob(ctx interface{}, jobID interface{}, actions ...interface{}) *MockImportMeta_UpdateJob_Call { + return &MockImportMeta_UpdateJob_Call{Call: _e.mock.On("UpdateJob", + append([]interface{}{ctx, jobID}, actions...)...)} +} + +func (_c *MockImportMeta_UpdateJob_Call) Run(run func(ctx context.Context, jobID int64, actions ...UpdateJobAction)) *MockImportMeta_UpdateJob_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]UpdateJobAction, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(UpdateJobAction) + } + } + run(args[0].(context.Context), args[1].(int64), variadicArgs...) + }) + return _c +} + +func (_c *MockImportMeta_UpdateJob_Call) Return(_a0 error) *MockImportMeta_UpdateJob_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImportMeta_UpdateJob_Call) RunAndReturn(run func(context.Context, int64, ...UpdateJobAction) error) *MockImportMeta_UpdateJob_Call { + _c.Call.Return(run) + return _c +} + +// UpdateTask provides a mock function with given fields: ctx, taskID, actions +func (_m *MockImportMeta) UpdateTask(ctx context.Context, taskID int64, actions ...UpdateAction) error { + _va := make([]interface{}, len(actions)) + for _i := range actions { + _va[_i] = actions[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, taskID) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + if len(ret) == 0 { + panic("no return value specified for UpdateTask") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int64, ...UpdateAction) error); ok { + r0 = rf(ctx, taskID, actions...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockImportMeta_UpdateTask_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateTask' +type MockImportMeta_UpdateTask_Call struct { + *mock.Call +} + +// UpdateTask is a helper method to define mock.On call +// - ctx context.Context +// - taskID int64 +// - actions ...UpdateAction +func (_e *MockImportMeta_Expecter) UpdateTask(ctx interface{}, taskID interface{}, actions ...interface{}) *MockImportMeta_UpdateTask_Call { + return &MockImportMeta_UpdateTask_Call{Call: _e.mock.On("UpdateTask", + append([]interface{}{ctx, taskID}, actions...)...)} +} + +func (_c *MockImportMeta_UpdateTask_Call) Run(run func(ctx context.Context, taskID int64, actions ...UpdateAction)) *MockImportMeta_UpdateTask_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]UpdateAction, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(UpdateAction) + } + } + run(args[0].(context.Context), args[1].(int64), variadicArgs...) + }) + return _c +} + +func (_c *MockImportMeta_UpdateTask_Call) Return(_a0 error) *MockImportMeta_UpdateTask_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockImportMeta_UpdateTask_Call) RunAndReturn(run func(context.Context, int64, ...UpdateAction) error) *MockImportMeta_UpdateTask_Call { + _c.Call.Return(run) + return _c +} + +// NewMockImportMeta creates a new instance of MockImportMeta. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockImportMeta(t interface { + mock.TestingT + Cleanup(func()) +}) *MockImportMeta { + mock := &MockImportMeta{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/internal/datacoord/server.go b/internal/datacoord/server.go index db981d5893..6f22c62cbd 100644 --- a/internal/datacoord/server.go +++ b/internal/datacoord/server.go @@ -36,6 +36,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" globalIDAllocator "github.com/milvus-io/milvus/internal/allocator" "github.com/milvus-io/milvus/internal/datacoord/allocator" "github.com/milvus-io/milvus/internal/datacoord/broker" @@ -49,15 +50,18 @@ import ( "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" + "github.com/milvus-io/milvus/internal/util/importutilv2" "github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/streamingutil" "github.com/milvus-io/milvus/pkg/v2/kv" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" + "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/util" "github.com/milvus-io/milvus/pkg/v2/util/expr" + "github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/logutil" "github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/metricsinfo" @@ -331,17 +335,71 @@ func (s *Server) initDataCoord() error { log.Info("init datacoord done", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("Address", s.address)) - s.initMessageAckCallback() + s.initMessageCallback() return nil } -// initMessageAckCallback initializes the message ack callback. +// initMessageCallback initializes the message callback. // TODO: we should build a ddl framework to handle the message ack callback for ddl messages -func (s *Server) initMessageAckCallback() { +func (s *Server) initMessageCallback() { registry.RegisterMessageAckCallback(message.MessageTypeDropPartition, func(ctx context.Context, msg message.MutableMessage) error { dropPartitionMsg := message.MustAsMutableDropPartitionMessageV1(msg) return s.NotifyDropPartition(ctx, msg.VChannel(), []int64{dropPartitionMsg.Header().PartitionId}) }) + + registry.RegisterMessageAckCallback(message.MessageTypeImport, func(ctx context.Context, msg message.MutableMessage) error { + importMsg := message.MustAsMutableImportMessageV1(msg) + body := importMsg.MustBody() + importResp, err := s.ImportV2(ctx, &internalpb.ImportRequestInternal{ + CollectionID: body.GetCollectionID(), + CollectionName: body.GetCollectionName(), + PartitionIDs: body.GetPartitionIDs(), + ChannelNames: []string{msg.VChannel()}, + Schema: body.GetSchema(), + Files: lo.Map(body.GetFiles(), func(file *msgpb.ImportFile, _ int) *internalpb.ImportFile { + return &internalpb.ImportFile{ + Id: file.GetId(), + Paths: file.GetPaths(), + } + }), + Options: funcutil.Map2KeyValuePair(body.GetOptions()), + DataTimestamp: body.GetBase().GetTimestamp(), + JobID: body.GetJobID(), + }) + err = merr.CheckRPCCall(importResp, err) + if errors.Is(err, merr.ErrCollectionNotFound) { + log.Ctx(ctx).Warn("import message failed because of collection not found, skip it", zap.String("job_id", importResp.GetJobID()), zap.Error(err)) + return nil + } + if err != nil { + log.Ctx(ctx).Warn("import message failed", zap.String("job_id", importResp.GetJobID()), zap.Error(err)) + return err + } + log.Ctx(ctx).Info("import message handled", zap.String("job_id", importResp.GetJobID())) + return nil + }) + + registry.RegisterMessageCheckCallback(message.MessageTypeImport, func(ctx context.Context, msg message.BroadcastMutableMessage) error { + importMsg := message.MustAsMutableImportMessageV1(msg) + b, err := importMsg.Body() + if err != nil { + return err + } + options := funcutil.Map2KeyValuePair(b.GetOptions()) + _, err = importutilv2.GetTimeoutTs(options) + if err != nil { + return err + } + err = ValidateBinlogImportRequest(ctx, s.meta.chunkManager, b.GetFiles(), options) + if err != nil { + return err + } + err = ValidateMaxImportJobExceed(ctx, s.importMeta) + if err != nil { + return err + } + return nil + }) } // Start initialize `Server` members and start loops, follow steps are taken: diff --git a/internal/datacoord/server_test.go b/internal/datacoord/server_test.go index 248b17ec1c..6cb0dcbc51 100644 --- a/internal/datacoord/server_test.go +++ b/internal/datacoord/server_test.go @@ -46,8 +46,10 @@ import ( "github.com/milvus-io/milvus/internal/datacoord/broker" "github.com/milvus-io/milvus/internal/datacoord/session" etcdkv "github.com/milvus-io/milvus/internal/kv/etcd" + mocks2 "github.com/milvus-io/milvus/internal/metastore/mocks" "github.com/milvus-io/milvus/internal/metastore/model" "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry" "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/dependency" "github.com/milvus-io/milvus/internal/util/sessionutil" @@ -57,6 +59,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/proto/indexpb" "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/proto/workerpb" + "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" "github.com/milvus-io/milvus/pkg/v2/util/etcd" "github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/metricsinfo" @@ -2698,3 +2701,71 @@ func TestUpdateAutoBalanceConfigLoop(t *testing.T) { wg.Wait() }) } + +func TestServer_InitMessageCallback(t *testing.T) { + ctx := context.Background() + + mockCatalog := mocks2.NewDataCoordCatalog(t) + mockChunkManager := mocks.NewChunkManager(t) + mockManager := NewMockManager(t) + + server := &Server{ + ctx: ctx, + meta: &meta{ + catalog: mockCatalog, + chunkManager: mockChunkManager, + segments: NewSegmentsInfo(), + }, + importMeta: &importMeta{}, + segmentManager: mockManager, + } + server.stateCode.Store(commonpb.StateCode_Abnormal) + + // Test initMessageCallback + server.initMessageCallback() + + // Test DropPartition message callback + dropPartitionMsg, err := message.NewDropPartitionMessageBuilderV1(). + WithVChannel("test_channel"). + WithHeader(&message.DropPartitionMessageHeader{ + CollectionId: 1, + PartitionId: 1, + }). + WithBody(&msgpb.DropPartitionRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_DropPartition, + }, + }). + BuildMutable() + assert.NoError(t, err) + err = registry.CallMessageAckCallback(ctx, dropPartitionMsg) + assert.Error(t, err) // server not healthy + + // Test Import message check callback + resourceKey := message.NewImportJobIDResourceKey(1) + msg, err := message.NewImportMessageBuilderV1(). + WithHeader(&message.ImportMessageHeader{}). + WithBody(&msgpb.ImportMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Import, + }, + }). + WithBroadcast([]string{"ch-0"}, resourceKey). + BuildBroadcast() + err = registry.CallMessageCheckCallback(ctx, msg) + assert.NoError(t, err) + + // Test Import message ack callback + importMsg, err := message.NewImportMessageBuilderV1(). + WithVChannel("test_channel"). + WithHeader(&message.ImportMessageHeader{}). + WithBody(&msgpb.ImportMsg{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_Import, + }, + }). + BuildMutable() + assert.NoError(t, err) + err = registry.CallMessageAckCallback(ctx, importMsg) + assert.Error(t, err) // server not healthy +} diff --git a/internal/datacoord/services.go b/internal/datacoord/services.go index 9af83818b5..d2ccf944ec 100644 --- a/internal/datacoord/services.go +++ b/internal/datacoord/services.go @@ -21,7 +21,6 @@ import ( "fmt" "math" "strconv" - "sync" "time" "github.com/cockroachdb/errors" @@ -44,9 +43,7 @@ import ( "github.com/milvus-io/milvus/pkg/v2/metrics" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" - "github.com/milvus-io/milvus/pkg/v2/util/conc" "github.com/milvus-io/milvus/pkg/v2/util/funcutil" - "github.com/milvus-io/milvus/pkg/v2/util/hardware" "github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/metricsinfo" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" @@ -1796,60 +1793,13 @@ func (s *Server) ImportV2(ctx context.Context, in *internalpb.ImportRequestInter files := in.GetFiles() isBackup := importutilv2.IsBackup(in.GetOptions()) if isBackup { - files = make([]*internalpb.ImportFile, 0) - pool := conc.NewPool[struct{}](hardware.GetCPUNum() * 2) - defer pool.Release() - futures := make([]*conc.Future[struct{}], 0, len(in.GetFiles())) - mu := &sync.Mutex{} - for _, importFile := range in.GetFiles() { - importFile := importFile - futures = append(futures, pool.Submit(func() (struct{}, error) { - segmentPrefixes, err := ListBinlogsAndGroupBySegment(ctx, s.meta.chunkManager, importFile) - if err != nil { - return struct{}{}, err - } - mu.Lock() - defer mu.Unlock() - files = append(files, segmentPrefixes...) - return struct{}{}, nil - })) - } - err = conc.AwaitAll(futures...) + files, err = ListBinlogImportRequestFiles(ctx, s.meta.chunkManager, files, in.GetOptions()) if err != nil { - resp.Status = merr.Status(merr.WrapErrImportFailed(fmt.Sprintf("list binlogs failed, err=%s", err))) + resp.Status = merr.Status(err) return resp, nil } - - files = lo.Filter(files, func(file *internalpb.ImportFile, _ int) bool { - return len(file.GetPaths()) > 0 - }) - if len(files) == 0 { - resp.Status = merr.Status(merr.WrapErrImportFailed(fmt.Sprintf("no binlog to import, input=%s", in.GetFiles()))) - return resp, nil - } - if len(files) > paramtable.Get().DataCoordCfg.MaxFilesPerImportReq.GetAsInt() { - resp.Status = merr.Status(merr.WrapErrImportFailed(fmt.Sprintf("The max number of import files should not exceed %d, but got %d", - paramtable.Get().DataCoordCfg.MaxFilesPerImportReq.GetAsInt(), len(files)))) - return resp, nil - } - log.Info("list binlogs prefixes for import", zap.Int("num", len(files)), zap.Any("binlog_prefixes", files)) } - // The import task does not need to be controlled for the time being, and additional development is required later. - // Here is a comment, because the current importv2 communicates through messages and needs to ensure idempotence. - // Adding this part of the logic will cause importv2 to retry infinitely until the previous import task is completed. - - // Check if the number of jobs exceeds the limit. - // maxNum := paramtable.Get().DataCoordCfg.MaxImportJobNum.GetAsInt() - // executingNum := s.importMeta.CountJobBy(ctx, WithoutJobStates(internalpb.ImportJobState_Completed, internalpb.ImportJobState_Failed)) - // if executingNum >= maxNum { - // resp.Status = merr.Status(merr.WrapErrImportFailed( - // fmt.Sprintf("The number of jobs has reached the limit, please try again later. " + - // "If your request is set to only import a single file, " + - // "please consider importing multiple files in one request for better efficiency."))) - // return resp, nil - // } - // Allocate file ids. idStart, _, err := s.allocator.AllocN(int64(len(files)) + 1) if err != nil { diff --git a/internal/datacoord/services_test.go b/internal/datacoord/services_test.go index 6de066d793..087dadbcf9 100644 --- a/internal/datacoord/services_test.go +++ b/internal/datacoord/services_test.go @@ -1404,13 +1404,6 @@ func TestImportV2(t *testing.T) { assert.Equal(t, int32(0), resp.GetStatus().GetCode()) jobs = s.importMeta.GetJobBy(context.TODO()) assert.Equal(t, 1, len(jobs)) - - // number of jobs reached the limit - // Params.Save(paramtable.Get().DataCoordCfg.MaxImportJobNum.Key, "1") - // resp, err = s.ImportV2(ctx, &internalpb.ImportRequestInternal{}) - // assert.NoError(t, err) - // assert.True(t, errors.Is(merr.Error(resp.GetStatus()), merr.ErrImportFailed)) - // Params.Reset(paramtable.Get().DataCoordCfg.MaxImportJobNum.Key) }) t.Run("GetImportProgress", func(t *testing.T) { diff --git a/internal/datanode/msghandlerimpl/msg_handler_impl.go b/internal/datanode/msghandlerimpl/msg_handler_impl.go index f45ccefe8d..101311e872 100644 --- a/internal/datanode/msghandlerimpl/msg_handler_impl.go +++ b/internal/datanode/msghandlerimpl/msg_handler_impl.go @@ -21,21 +21,8 @@ package msghandlerimpl import ( "context" - "github.com/cockroachdb/errors" - "github.com/samber/lo" - "go.uber.org/zap" - - "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" - "github.com/milvus-io/milvus/internal/distributed/streaming" "github.com/milvus-io/milvus/internal/flushcommon/broker" - "github.com/milvus-io/milvus/internal/flushcommon/util" - "github.com/milvus-io/milvus/pkg/v2/log" - "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" - "github.com/milvus-io/milvus/pkg/v2/streaming/util/types" - "github.com/milvus-io/milvus/pkg/v2/util/funcutil" - "github.com/milvus-io/milvus/pkg/v2/util/merr" - "github.com/milvus-io/milvus/pkg/v2/util/retry" ) type msgHandlerImpl struct { @@ -54,41 +41,6 @@ func (m *msgHandlerImpl) HandleManualFlush(flushMsg message.ImmutableManualFlush panic("unreachable code") } -func (m *msgHandlerImpl) HandleImport(ctx context.Context, vchannel string, importMsg *msgpb.ImportMsg) error { - return retry.Do(ctx, func() (err error) { - defer func() { - if err == nil { - err = streaming.WAL().Broadcast().Ack(ctx, types.BroadcastAckRequest{ - BroadcastID: uint64(importMsg.GetJobID()), - VChannel: vchannel, - }) - } - }() - importResp, err := m.broker.ImportV2(ctx, &internalpb.ImportRequestInternal{ - CollectionID: importMsg.GetCollectionID(), - CollectionName: importMsg.GetCollectionName(), - PartitionIDs: importMsg.GetPartitionIDs(), - ChannelNames: []string{vchannel}, - Schema: importMsg.GetSchema(), - Files: lo.Map(importMsg.GetFiles(), util.ConvertInternalImportFile), - Options: funcutil.Map2KeyValuePair(importMsg.GetOptions()), - DataTimestamp: importMsg.GetBase().GetTimestamp(), - JobID: importMsg.GetJobID(), - }) - err = merr.CheckRPCCall(importResp, err) - if errors.Is(err, merr.ErrCollectionNotFound) { - log.Ctx(ctx).Warn("import message failed because of collection not found, skip it", zap.String("job_id", importResp.GetJobID()), zap.Error(err)) - return nil - } - if err != nil { - log.Ctx(ctx).Warn("import message failed", zap.String("job_id", importResp.GetJobID()), zap.Error(err)) - return err - } - log.Ctx(ctx).Info("import message handled", zap.String("job_id", importResp.GetJobID())) - return nil - }, retry.AttemptAlways()) -} - func (impl *msgHandlerImpl) HandleSchemaChange(ctx context.Context, msg message.ImmutableSchemaChangeMessageV2) error { panic("unreachable code") } diff --git a/internal/datanode/msghandlerimpl/msg_handler_impl_test.go b/internal/datanode/msghandlerimpl/msg_handler_impl_test.go index e9db0c1c6a..90170e2b44 100644 --- a/internal/datanode/msghandlerimpl/msg_handler_impl_test.go +++ b/internal/datanode/msghandlerimpl/msg_handler_impl_test.go @@ -19,21 +19,16 @@ package msghandlerimpl import ( - "context" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/milvus-io/milvus/internal/distributed/streaming" "github.com/milvus-io/milvus/internal/flushcommon/broker" - "github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming" "github.com/milvus-io/milvus/pkg/v2/util/paramtable" ) func TestMsgHandlerImpl(t *testing.T) { paramtable.Init() - ctx := context.Background() b := broker.NewMockBroker(t) m := NewMsgHandlerImpl(b) assert.Panics(t, func() { @@ -45,17 +40,4 @@ func TestMsgHandlerImpl(t *testing.T) { assert.Panics(t, func() { m.HandleManualFlush(nil) }) - t.Run("HandleImport success", func(t *testing.T) { - wal := mock_streaming.NewMockWALAccesser(t) - bo := mock_streaming.NewMockBroadcast(t) - wal.EXPECT().Broadcast().Return(bo) - bo.EXPECT().Ack(mock.Anything, mock.Anything).Return(nil) - streaming.SetWALForTest(wal) - defer streaming.RecoverWALForTest() - - b.EXPECT().ImportV2(mock.Anything, mock.Anything).Return(nil, assert.AnError).Once() - b.EXPECT().ImportV2(mock.Anything, mock.Anything).Return(nil, nil).Once() - err := m.HandleImport(ctx, "", nil) - assert.NoError(t, err) - }) } diff --git a/internal/flushcommon/pipeline/flow_graph_dd_node.go b/internal/flushcommon/pipeline/flow_graph_dd_node.go index 2171e31469..8b7c9515b7 100644 --- a/internal/flushcommon/pipeline/flow_graph_dd_node.go +++ b/internal/flushcommon/pipeline/flow_graph_dd_node.go @@ -282,21 +282,6 @@ func (ddn *ddNode) Operate(in []Msg) []Msg { } else { logger.Info("handle manual flush message success") } - case commonpb.MsgType_Import: - importMsg := msg.(*msgstream.ImportMsg) - if importMsg.GetCollectionID() != ddn.collectionID { - continue - } - logger := log.With( - zap.String("vchannel", ddn.Name()), - zap.Int32("msgType", int32(msg.Type())), - ) - logger.Info("receive import message") - if err := ddn.msgHandler.HandleImport(context.Background(), ddn.vChannelName, importMsg.ImportMsg); err != nil { - logger.Warn("handle import message failed", zap.Error(err)) - } else { - logger.Info("handle import message success") - } case commonpb.MsgType_AddCollectionField: schemaMsg := msg.(*adaptor.SchemaChangeMessageBody) header := schemaMsg.SchemaChangeMessage.Header() diff --git a/internal/flushcommon/util/msg_handler.go b/internal/flushcommon/util/msg_handler.go index 9ab333efb5..80a07c1196 100644 --- a/internal/flushcommon/util/msg_handler.go +++ b/internal/flushcommon/util/msg_handler.go @@ -31,8 +31,6 @@ type MsgHandler interface { HandleManualFlush(flushMsg message.ImmutableManualFlushMessageV2) error - HandleImport(ctx context.Context, vchannel string, importMsg *msgpb.ImportMsg) error - HandleSchemaChange(ctx context.Context, schemaChangeMsg message.ImmutableSchemaChangeMessageV2) error } diff --git a/internal/mocks/flushcommon/mock_util/mock_MsgHandler.go b/internal/mocks/flushcommon/mock_util/mock_MsgHandler.go index 6d1a4265c1..d8e0720d5d 100644 --- a/internal/mocks/flushcommon/mock_util/mock_MsgHandler.go +++ b/internal/mocks/flushcommon/mock_util/mock_MsgHandler.go @@ -7,8 +7,6 @@ import ( message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" mock "github.com/stretchr/testify/mock" - - msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" ) // MockMsgHandler is an autogenerated mock type for the MsgHandler type @@ -117,54 +115,6 @@ func (_c *MockMsgHandler_HandleFlush_Call) RunAndReturn(run func(message.Immutab return _c } -// HandleImport provides a mock function with given fields: ctx, vchannel, importMsg -func (_m *MockMsgHandler) HandleImport(ctx context.Context, vchannel string, importMsg *msgpb.ImportMsg) error { - ret := _m.Called(ctx, vchannel, importMsg) - - if len(ret) == 0 { - panic("no return value specified for HandleImport") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, *msgpb.ImportMsg) error); ok { - r0 = rf(ctx, vchannel, importMsg) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// MockMsgHandler_HandleImport_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HandleImport' -type MockMsgHandler_HandleImport_Call struct { - *mock.Call -} - -// HandleImport is a helper method to define mock.On call -// - ctx context.Context -// - vchannel string -// - importMsg *msgpb.ImportMsg -func (_e *MockMsgHandler_Expecter) HandleImport(ctx interface{}, vchannel interface{}, importMsg interface{}) *MockMsgHandler_HandleImport_Call { - return &MockMsgHandler_HandleImport_Call{Call: _e.mock.On("HandleImport", ctx, vchannel, importMsg)} -} - -func (_c *MockMsgHandler_HandleImport_Call) Run(run func(ctx context.Context, vchannel string, importMsg *msgpb.ImportMsg)) *MockMsgHandler_HandleImport_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(*msgpb.ImportMsg)) - }) - return _c -} - -func (_c *MockMsgHandler_HandleImport_Call) Return(_a0 error) *MockMsgHandler_HandleImport_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockMsgHandler_HandleImport_Call) RunAndReturn(run func(context.Context, string, *msgpb.ImportMsg) error) *MockMsgHandler_HandleImport_Call { - _c.Call.Return(run) - return _c -} - // HandleManualFlush provides a mock function with given fields: flushMsg func (_m *MockMsgHandler) HandleManualFlush(flushMsg message.ImmutableManualFlushMessageV2) error { ret := _m.Called(flushMsg) diff --git a/internal/streamingcoord/server/broadcaster/broadcast_manager.go b/internal/streamingcoord/server/broadcaster/broadcast_manager.go index df48e0fb77..3f59beaebf 100644 --- a/internal/streamingcoord/server/broadcaster/broadcast_manager.go +++ b/internal/streamingcoord/server/broadcaster/broadcast_manager.go @@ -78,21 +78,6 @@ func (bm *broadcastTaskManager) AddTask(ctx context.Context, msg message.Broadca // assignID assigns the broadcast id to the message. func (bm *broadcastTaskManager) assignID(ctx context.Context, msg message.BroadcastMutableMessage) (message.BroadcastMutableMessage, error) { - // TODO: current implementation the header cannot be seen at flusher itself. - // only import message use it, so temporarily set the broadcast id here. - // need to refactor the message to make the broadcast header visible to flusher. - if msg.MessageType() == message.MessageTypeImport { - importMsg, err := message.AsMutableImportMessageV1(msg) - if err != nil { - return nil, err - } - body, err := importMsg.Body() - if err != nil { - return nil, err - } - return msg.WithBroadcastID(uint64(body.JobID)), nil - } - id, err := resource.Resource().IDAllocator().Allocate(ctx) if err != nil { return nil, errors.Wrapf(err, "allocate new id failed") diff --git a/internal/streamingcoord/server/broadcaster/broadcaster_impl.go b/internal/streamingcoord/server/broadcaster/broadcaster_impl.go index 85480f257e..b52d962ab7 100644 --- a/internal/streamingcoord/server/broadcaster/broadcaster_impl.go +++ b/internal/streamingcoord/server/broadcaster/broadcaster_impl.go @@ -7,6 +7,7 @@ import ( "go.uber.org/zap" + "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry" "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" "github.com/milvus-io/milvus/internal/util/streamingutil/status" "github.com/milvus-io/milvus/pkg/v2/log" @@ -69,6 +70,14 @@ func (b *broadcasterImpl) Broadcast(ctx context.Context, msg message.BroadcastMu } }() + // We need to check if the message is valid before adding it to the broadcaster. + // TODO: add resource key lock here to avoid state race condition. + // TODO: add all ddl to check operation here after ddl framework is ready. + if err := registry.CallMessageCheckCallback(ctx, msg); err != nil { + b.Logger().Warn("check message ack callback failed", zap.Error(err)) + return nil, err + } + t, err := b.manager.AddTask(ctx, msg) if err != nil { return nil, err diff --git a/internal/streamingcoord/server/broadcaster/broadcaster_test.go b/internal/streamingcoord/server/broadcaster/broadcaster_test.go index a817772ab5..9ca93499f4 100644 --- a/internal/streamingcoord/server/broadcaster/broadcaster_test.go +++ b/internal/streamingcoord/server/broadcaster/broadcaster_test.go @@ -14,6 +14,7 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" "github.com/milvus-io/milvus/internal/mocks/mock_metastore" "github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_broadcaster" + "github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry" "github.com/milvus-io/milvus/internal/streamingcoord/server/resource" internaltypes "github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/util/idalloc" @@ -28,6 +29,7 @@ import ( ) func TestBroadcaster(t *testing.T) { + registry.ResetRegistration() paramtable.Init() meta := mock_metastore.NewMockStreamingCoordCataLog(t) diff --git a/internal/streamingcoord/server/broadcaster/registry/message_callback.go b/internal/streamingcoord/server/broadcaster/registry/ack_message_callback.go similarity index 78% rename from internal/streamingcoord/server/broadcaster/registry/message_callback.go rename to internal/streamingcoord/server/broadcaster/registry/ack_message_callback.go index 7f174d5ec8..2f5c8a6a9f 100644 --- a/internal/streamingcoord/server/broadcaster/registry/message_callback.go +++ b/internal/streamingcoord/server/broadcaster/registry/ack_message_callback.go @@ -13,23 +13,25 @@ import ( // init the message ack callbacks func init() { resetMessageAckCallbacks() + resetMessageCheckCallbacks() } // resetMessageAckCallbacks resets the message ack callbacks. func resetMessageAckCallbacks() { - messageAckCallbacks = map[message.MessageType]*syncutil.Future[MessageCallback]{ - message.MessageTypeDropPartition: syncutil.NewFuture[MessageCallback](), + messageAckCallbacks = map[message.MessageType]*syncutil.Future[MessageAckCallback]{ + message.MessageTypeDropPartition: syncutil.NewFuture[MessageAckCallback](), + message.MessageTypeImport: syncutil.NewFuture[MessageAckCallback](), } } -// MessageCallback is the callback function for the message type. -type MessageCallback = func(ctx context.Context, msg message.MutableMessage) error +// MessageAckCallback is the callback function for the message type. +type MessageAckCallback = func(ctx context.Context, msg message.MutableMessage) error // messageAckCallbacks is the map of message type to the callback function. -var messageAckCallbacks map[message.MessageType]*syncutil.Future[MessageCallback] +var messageAckCallbacks map[message.MessageType]*syncutil.Future[MessageAckCallback] // RegisterMessageAckCallback registers the callback function for the message type. -func RegisterMessageAckCallback(typ message.MessageType, callback MessageCallback) { +func RegisterMessageAckCallback(typ message.MessageType, callback MessageAckCallback) { future, ok := messageAckCallbacks[typ] if !ok { panic(fmt.Sprintf("the future of message callback for type %s is not registered", typ)) diff --git a/internal/streamingcoord/server/broadcaster/registry/message_callback_test.go b/internal/streamingcoord/server/broadcaster/registry/ack_message_callback_test.go similarity index 100% rename from internal/streamingcoord/server/broadcaster/registry/message_callback_test.go rename to internal/streamingcoord/server/broadcaster/registry/ack_message_callback_test.go diff --git a/internal/streamingcoord/server/broadcaster/registry/check_message_callback.go b/internal/streamingcoord/server/broadcaster/registry/check_message_callback.go new file mode 100644 index 0000000000..eabf7cc50b --- /dev/null +++ b/internal/streamingcoord/server/broadcaster/registry/check_message_callback.go @@ -0,0 +1,51 @@ +package registry + +import ( + "context" + "fmt" + + "github.com/cockroachdb/errors" + + "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" + "github.com/milvus-io/milvus/pkg/v2/util/syncutil" +) + +// MessageCheckCallback is the callback function for the message type. +type MessageCheckCallback = func(ctx context.Context, msg message.BroadcastMutableMessage) error + +// resetMessageCheckCallbacks resets the message check callbacks. +func resetMessageCheckCallbacks() { + messageCheckCallbacks = map[message.MessageType]*syncutil.Future[MessageCheckCallback]{ + message.MessageTypeImport: syncutil.NewFuture[MessageCheckCallback](), + } +} + +// messageCheckCallbacks is the map of message type to the callback function. +var messageCheckCallbacks map[message.MessageType]*syncutil.Future[MessageCheckCallback] + +// RegisterMessageCheckCallback registers the callback function for the message type. +func RegisterMessageCheckCallback(typ message.MessageType, callback MessageCheckCallback) { + future, ok := messageCheckCallbacks[typ] + if !ok { + panic(fmt.Sprintf("the future of check message callback for type %s is not registered", typ)) + } + if future.Ready() { + // only for test, the register callback should be called once and only once + return + } + future.Set(callback) +} + +// CallMessageCheckCallback calls the callback function for the message type. +func CallMessageCheckCallback(ctx context.Context, msg message.BroadcastMutableMessage) error { + callbackFuture, ok := messageCheckCallbacks[msg.MessageType()] + if !ok { + // No callback need tobe called, return nil + return nil + } + callback, err := callbackFuture.GetWithContext(ctx) + if err != nil { + return errors.Wrap(err, "when waiting callback registered") + } + return callback(ctx, msg) +} diff --git a/internal/streamingcoord/server/broadcaster/registry/check_message_callback_test.go b/internal/streamingcoord/server/broadcaster/registry/check_message_callback_test.go new file mode 100644 index 0000000000..92c714acf1 --- /dev/null +++ b/internal/streamingcoord/server/broadcaster/registry/check_message_callback_test.go @@ -0,0 +1,50 @@ +package registry + +import ( + "context" + "testing" + "time" + + "github.com/cockroachdb/errors" + "github.com/stretchr/testify/assert" + + "github.com/milvus-io/milvus/pkg/v2/mocks/streaming/util/mock_message" + "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" +) + +func TestCheckMessageCallbackRegistration(t *testing.T) { + // Reset callbacks before test + resetMessageCheckCallbacks() + + // Test registering a callback + called := false + callback := func(ctx context.Context, msg message.BroadcastMutableMessage) error { + called = true + return nil + } + + // Register callback for DropPartition message type + RegisterMessageCheckCallback(message.MessageTypeImport, callback) + + // Verify callback was registered + callbackFuture, ok := messageCheckCallbacks[message.MessageTypeImport] + assert.True(t, ok) + assert.NotNil(t, callbackFuture) + + // Create a mock message + msg := mock_message.NewMockBroadcastMutableMessage(t) + msg.EXPECT().MessageType().Return(message.MessageTypeImport) + + // Call the callback + err := CallMessageCheckCallback(context.Background(), msg) + assert.NoError(t, err) + assert.True(t, called) + + resetMessageCheckCallbacks() + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + err = CallMessageCheckCallback(ctx, msg) + assert.Error(t, err) + assert.True(t, errors.Is(err, context.DeadlineExceeded)) +} diff --git a/internal/streamingcoord/server/broadcaster/registry/test_utility.go b/internal/streamingcoord/server/broadcaster/registry/test_utility.go index 8b13f23910..5578d3ab78 100644 --- a/internal/streamingcoord/server/broadcaster/registry/test_utility.go +++ b/internal/streamingcoord/server/broadcaster/registry/test_utility.go @@ -10,4 +10,5 @@ func ResetRegistration() { localRegistry[AppendOperatorTypeMsgstream] = syncutil.NewFuture[AppendOperator]() localRegistry[AppendOperatorTypeStreaming] = syncutil.NewFuture[AppendOperator]() resetMessageAckCallbacks() + resetMessageCheckCallbacks() } diff --git a/internal/streamingnode/server/flusher/flusherimpl/msg_handler_impl.go b/internal/streamingnode/server/flusher/flusherimpl/msg_handler_impl.go index f3f3ba3347..f8a1480d51 100644 --- a/internal/streamingnode/server/flusher/flusherimpl/msg_handler_impl.go +++ b/internal/streamingnode/server/flusher/flusherimpl/msg_handler_impl.go @@ -20,18 +20,13 @@ import ( "context" "github.com/cockroachdb/errors" - "github.com/samber/lo" "go.uber.org/zap" - "github.com/milvus-io/milvus-proto/go-api/v2/msgpb" - "github.com/milvus-io/milvus/internal/flushcommon/util" "github.com/milvus-io/milvus/internal/flushcommon/writebuffer" "github.com/milvus-io/milvus/internal/streamingnode/server/resource" "github.com/milvus-io/milvus/pkg/v2/log" "github.com/milvus-io/milvus/pkg/v2/proto/datapb" - "github.com/milvus-io/milvus/pkg/v2/proto/internalpb" "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" - "github.com/milvus-io/milvus/pkg/v2/util/funcutil" "github.com/milvus-io/milvus/pkg/v2/util/merr" "github.com/milvus-io/milvus/pkg/v2/util/retry" ) @@ -109,34 +104,3 @@ func (impl *msgHandlerImpl) HandleManualFlush(flushMsg message.ImmutableManualFl func (impl *msgHandlerImpl) HandleSchemaChange(ctx context.Context, msg message.ImmutableSchemaChangeMessageV2) error { return impl.wbMgr.SealSegments(context.Background(), msg.VChannel(), msg.Header().FlushedSegmentIds) } - -func (impl *msgHandlerImpl) HandleImport(ctx context.Context, vchannel string, importMsg *msgpb.ImportMsg) error { - return retry.Do(ctx, func() (err error) { - client, err := resource.Resource().MixCoordClient().GetWithContext(ctx) - if err != nil { - return err - } - importResp, err := client.ImportV2(ctx, &internalpb.ImportRequestInternal{ - CollectionID: importMsg.GetCollectionID(), - CollectionName: importMsg.GetCollectionName(), - PartitionIDs: importMsg.GetPartitionIDs(), - ChannelNames: []string{vchannel}, - Schema: importMsg.GetSchema(), - Files: lo.Map(importMsg.GetFiles(), util.ConvertInternalImportFile), - Options: funcutil.Map2KeyValuePair(importMsg.GetOptions()), - DataTimestamp: importMsg.GetBase().GetTimestamp(), - JobID: importMsg.GetJobID(), - }) - err = merr.CheckRPCCall(importResp, err) - if errors.Is(err, merr.ErrCollectionNotFound) { - log.Ctx(ctx).Warn("import message failed because of collection not found, skip it", zap.String("job_id", importResp.GetJobID()), zap.Error(err)) - return nil - } - if err != nil { - log.Ctx(ctx).Warn("import message failed", zap.String("job_id", importResp.GetJobID()), zap.Error(err)) - return err - } - log.Ctx(ctx).Info("import message handled", zap.String("job_id", importResp.GetJobID())) - return nil - }, retry.AttemptAlways()) -} diff --git a/pkg/.mockery_pkg.yaml b/pkg/.mockery_pkg.yaml index e559350c2e..bf47d9abbe 100644 --- a/pkg/.mockery_pkg.yaml +++ b/pkg/.mockery_pkg.yaml @@ -14,6 +14,7 @@ packages: ImmutableMessage: ImmutableTxnMessage: MutableMessage: + BroadcastMutableMessage: RProperties: github.com/milvus-io/milvus/pkg/v2/streaming/walimpls: interfaces: diff --git a/pkg/mocks/streaming/util/mock_message/mock_BroadcastMutableMessage.go b/pkg/mocks/streaming/util/mock_message/mock_BroadcastMutableMessage.go new file mode 100644 index 0000000000..4ef185f707 --- /dev/null +++ b/pkg/mocks/streaming/util/mock_message/mock_BroadcastMutableMessage.go @@ -0,0 +1,636 @@ +// Code generated by mockery v2.53.3. DO NOT EDIT. + +package mock_message + +import ( + message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message" + mock "github.com/stretchr/testify/mock" + + zapcore "go.uber.org/zap/zapcore" +) + +// MockBroadcastMutableMessage is an autogenerated mock type for the BroadcastMutableMessage type +type MockBroadcastMutableMessage struct { + mock.Mock +} + +type MockBroadcastMutableMessage_Expecter struct { + mock *mock.Mock +} + +func (_m *MockBroadcastMutableMessage) EXPECT() *MockBroadcastMutableMessage_Expecter { + return &MockBroadcastMutableMessage_Expecter{mock: &_m.Mock} +} + +// BarrierTimeTick provides a mock function with no fields +func (_m *MockBroadcastMutableMessage) BarrierTimeTick() uint64 { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for BarrierTimeTick") + } + + var r0 uint64 + if rf, ok := ret.Get(0).(func() uint64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint64) + } + + return r0 +} + +// MockBroadcastMutableMessage_BarrierTimeTick_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BarrierTimeTick' +type MockBroadcastMutableMessage_BarrierTimeTick_Call struct { + *mock.Call +} + +// BarrierTimeTick is a helper method to define mock.On call +func (_e *MockBroadcastMutableMessage_Expecter) BarrierTimeTick() *MockBroadcastMutableMessage_BarrierTimeTick_Call { + return &MockBroadcastMutableMessage_BarrierTimeTick_Call{Call: _e.mock.On("BarrierTimeTick")} +} + +func (_c *MockBroadcastMutableMessage_BarrierTimeTick_Call) Run(run func()) *MockBroadcastMutableMessage_BarrierTimeTick_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockBroadcastMutableMessage_BarrierTimeTick_Call) Return(_a0 uint64) *MockBroadcastMutableMessage_BarrierTimeTick_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroadcastMutableMessage_BarrierTimeTick_Call) RunAndReturn(run func() uint64) *MockBroadcastMutableMessage_BarrierTimeTick_Call { + _c.Call.Return(run) + return _c +} + +// BroadcastHeader provides a mock function with no fields +func (_m *MockBroadcastMutableMessage) BroadcastHeader() *message.BroadcastHeader { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for BroadcastHeader") + } + + var r0 *message.BroadcastHeader + if rf, ok := ret.Get(0).(func() *message.BroadcastHeader); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*message.BroadcastHeader) + } + } + + return r0 +} + +// MockBroadcastMutableMessage_BroadcastHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BroadcastHeader' +type MockBroadcastMutableMessage_BroadcastHeader_Call struct { + *mock.Call +} + +// BroadcastHeader is a helper method to define mock.On call +func (_e *MockBroadcastMutableMessage_Expecter) BroadcastHeader() *MockBroadcastMutableMessage_BroadcastHeader_Call { + return &MockBroadcastMutableMessage_BroadcastHeader_Call{Call: _e.mock.On("BroadcastHeader")} +} + +func (_c *MockBroadcastMutableMessage_BroadcastHeader_Call) Run(run func()) *MockBroadcastMutableMessage_BroadcastHeader_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockBroadcastMutableMessage_BroadcastHeader_Call) Return(_a0 *message.BroadcastHeader) *MockBroadcastMutableMessage_BroadcastHeader_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroadcastMutableMessage_BroadcastHeader_Call) RunAndReturn(run func() *message.BroadcastHeader) *MockBroadcastMutableMessage_BroadcastHeader_Call { + _c.Call.Return(run) + return _c +} + +// EstimateSize provides a mock function with no fields +func (_m *MockBroadcastMutableMessage) EstimateSize() int { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for EstimateSize") + } + + var r0 int + if rf, ok := ret.Get(0).(func() int); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(int) + } + + return r0 +} + +// MockBroadcastMutableMessage_EstimateSize_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EstimateSize' +type MockBroadcastMutableMessage_EstimateSize_Call struct { + *mock.Call +} + +// EstimateSize is a helper method to define mock.On call +func (_e *MockBroadcastMutableMessage_Expecter) EstimateSize() *MockBroadcastMutableMessage_EstimateSize_Call { + return &MockBroadcastMutableMessage_EstimateSize_Call{Call: _e.mock.On("EstimateSize")} +} + +func (_c *MockBroadcastMutableMessage_EstimateSize_Call) Run(run func()) *MockBroadcastMutableMessage_EstimateSize_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockBroadcastMutableMessage_EstimateSize_Call) Return(_a0 int) *MockBroadcastMutableMessage_EstimateSize_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroadcastMutableMessage_EstimateSize_Call) RunAndReturn(run func() int) *MockBroadcastMutableMessage_EstimateSize_Call { + _c.Call.Return(run) + return _c +} + +// IsPersisted provides a mock function with no fields +func (_m *MockBroadcastMutableMessage) IsPersisted() bool { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for IsPersisted") + } + + var r0 bool + if rf, ok := ret.Get(0).(func() bool); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(bool) + } + + return r0 +} + +// MockBroadcastMutableMessage_IsPersisted_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsPersisted' +type MockBroadcastMutableMessage_IsPersisted_Call struct { + *mock.Call +} + +// IsPersisted is a helper method to define mock.On call +func (_e *MockBroadcastMutableMessage_Expecter) IsPersisted() *MockBroadcastMutableMessage_IsPersisted_Call { + return &MockBroadcastMutableMessage_IsPersisted_Call{Call: _e.mock.On("IsPersisted")} +} + +func (_c *MockBroadcastMutableMessage_IsPersisted_Call) Run(run func()) *MockBroadcastMutableMessage_IsPersisted_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockBroadcastMutableMessage_IsPersisted_Call) Return(_a0 bool) *MockBroadcastMutableMessage_IsPersisted_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroadcastMutableMessage_IsPersisted_Call) RunAndReturn(run func() bool) *MockBroadcastMutableMessage_IsPersisted_Call { + _c.Call.Return(run) + return _c +} + +// MarshalLogObject provides a mock function with given fields: _a0 +func (_m *MockBroadcastMutableMessage) MarshalLogObject(_a0 zapcore.ObjectEncoder) error { + ret := _m.Called(_a0) + + if len(ret) == 0 { + panic("no return value specified for MarshalLogObject") + } + + var r0 error + if rf, ok := ret.Get(0).(func(zapcore.ObjectEncoder) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockBroadcastMutableMessage_MarshalLogObject_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MarshalLogObject' +type MockBroadcastMutableMessage_MarshalLogObject_Call struct { + *mock.Call +} + +// MarshalLogObject is a helper method to define mock.On call +// - _a0 zapcore.ObjectEncoder +func (_e *MockBroadcastMutableMessage_Expecter) MarshalLogObject(_a0 interface{}) *MockBroadcastMutableMessage_MarshalLogObject_Call { + return &MockBroadcastMutableMessage_MarshalLogObject_Call{Call: _e.mock.On("MarshalLogObject", _a0)} +} + +func (_c *MockBroadcastMutableMessage_MarshalLogObject_Call) Run(run func(_a0 zapcore.ObjectEncoder)) *MockBroadcastMutableMessage_MarshalLogObject_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(zapcore.ObjectEncoder)) + }) + return _c +} + +func (_c *MockBroadcastMutableMessage_MarshalLogObject_Call) Return(_a0 error) *MockBroadcastMutableMessage_MarshalLogObject_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroadcastMutableMessage_MarshalLogObject_Call) RunAndReturn(run func(zapcore.ObjectEncoder) error) *MockBroadcastMutableMessage_MarshalLogObject_Call { + _c.Call.Return(run) + return _c +} + +// MessageType provides a mock function with no fields +func (_m *MockBroadcastMutableMessage) MessageType() message.MessageType { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for MessageType") + } + + var r0 message.MessageType + if rf, ok := ret.Get(0).(func() message.MessageType); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(message.MessageType) + } + + return r0 +} + +// MockBroadcastMutableMessage_MessageType_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MessageType' +type MockBroadcastMutableMessage_MessageType_Call struct { + *mock.Call +} + +// MessageType is a helper method to define mock.On call +func (_e *MockBroadcastMutableMessage_Expecter) MessageType() *MockBroadcastMutableMessage_MessageType_Call { + return &MockBroadcastMutableMessage_MessageType_Call{Call: _e.mock.On("MessageType")} +} + +func (_c *MockBroadcastMutableMessage_MessageType_Call) Run(run func()) *MockBroadcastMutableMessage_MessageType_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockBroadcastMutableMessage_MessageType_Call) Return(_a0 message.MessageType) *MockBroadcastMutableMessage_MessageType_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroadcastMutableMessage_MessageType_Call) RunAndReturn(run func() message.MessageType) *MockBroadcastMutableMessage_MessageType_Call { + _c.Call.Return(run) + return _c +} + +// Payload provides a mock function with no fields +func (_m *MockBroadcastMutableMessage) Payload() []byte { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Payload") + } + + var r0 []byte + if rf, ok := ret.Get(0).(func() []byte); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + return r0 +} + +// MockBroadcastMutableMessage_Payload_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Payload' +type MockBroadcastMutableMessage_Payload_Call struct { + *mock.Call +} + +// Payload is a helper method to define mock.On call +func (_e *MockBroadcastMutableMessage_Expecter) Payload() *MockBroadcastMutableMessage_Payload_Call { + return &MockBroadcastMutableMessage_Payload_Call{Call: _e.mock.On("Payload")} +} + +func (_c *MockBroadcastMutableMessage_Payload_Call) Run(run func()) *MockBroadcastMutableMessage_Payload_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockBroadcastMutableMessage_Payload_Call) Return(_a0 []byte) *MockBroadcastMutableMessage_Payload_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroadcastMutableMessage_Payload_Call) RunAndReturn(run func() []byte) *MockBroadcastMutableMessage_Payload_Call { + _c.Call.Return(run) + return _c +} + +// Properties provides a mock function with no fields +func (_m *MockBroadcastMutableMessage) Properties() message.RProperties { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Properties") + } + + var r0 message.RProperties + if rf, ok := ret.Get(0).(func() message.RProperties); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(message.RProperties) + } + } + + return r0 +} + +// MockBroadcastMutableMessage_Properties_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Properties' +type MockBroadcastMutableMessage_Properties_Call struct { + *mock.Call +} + +// Properties is a helper method to define mock.On call +func (_e *MockBroadcastMutableMessage_Expecter) Properties() *MockBroadcastMutableMessage_Properties_Call { + return &MockBroadcastMutableMessage_Properties_Call{Call: _e.mock.On("Properties")} +} + +func (_c *MockBroadcastMutableMessage_Properties_Call) Run(run func()) *MockBroadcastMutableMessage_Properties_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockBroadcastMutableMessage_Properties_Call) Return(_a0 message.RProperties) *MockBroadcastMutableMessage_Properties_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroadcastMutableMessage_Properties_Call) RunAndReturn(run func() message.RProperties) *MockBroadcastMutableMessage_Properties_Call { + _c.Call.Return(run) + return _c +} + +// SplitIntoMutableMessage provides a mock function with no fields +func (_m *MockBroadcastMutableMessage) SplitIntoMutableMessage() []message.MutableMessage { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for SplitIntoMutableMessage") + } + + var r0 []message.MutableMessage + if rf, ok := ret.Get(0).(func() []message.MutableMessage); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]message.MutableMessage) + } + } + + return r0 +} + +// MockBroadcastMutableMessage_SplitIntoMutableMessage_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SplitIntoMutableMessage' +type MockBroadcastMutableMessage_SplitIntoMutableMessage_Call struct { + *mock.Call +} + +// SplitIntoMutableMessage is a helper method to define mock.On call +func (_e *MockBroadcastMutableMessage_Expecter) SplitIntoMutableMessage() *MockBroadcastMutableMessage_SplitIntoMutableMessage_Call { + return &MockBroadcastMutableMessage_SplitIntoMutableMessage_Call{Call: _e.mock.On("SplitIntoMutableMessage")} +} + +func (_c *MockBroadcastMutableMessage_SplitIntoMutableMessage_Call) Run(run func()) *MockBroadcastMutableMessage_SplitIntoMutableMessage_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockBroadcastMutableMessage_SplitIntoMutableMessage_Call) Return(_a0 []message.MutableMessage) *MockBroadcastMutableMessage_SplitIntoMutableMessage_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroadcastMutableMessage_SplitIntoMutableMessage_Call) RunAndReturn(run func() []message.MutableMessage) *MockBroadcastMutableMessage_SplitIntoMutableMessage_Call { + _c.Call.Return(run) + return _c +} + +// TimeTick provides a mock function with no fields +func (_m *MockBroadcastMutableMessage) TimeTick() uint64 { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for TimeTick") + } + + var r0 uint64 + if rf, ok := ret.Get(0).(func() uint64); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(uint64) + } + + return r0 +} + +// MockBroadcastMutableMessage_TimeTick_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TimeTick' +type MockBroadcastMutableMessage_TimeTick_Call struct { + *mock.Call +} + +// TimeTick is a helper method to define mock.On call +func (_e *MockBroadcastMutableMessage_Expecter) TimeTick() *MockBroadcastMutableMessage_TimeTick_Call { + return &MockBroadcastMutableMessage_TimeTick_Call{Call: _e.mock.On("TimeTick")} +} + +func (_c *MockBroadcastMutableMessage_TimeTick_Call) Run(run func()) *MockBroadcastMutableMessage_TimeTick_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockBroadcastMutableMessage_TimeTick_Call) Return(_a0 uint64) *MockBroadcastMutableMessage_TimeTick_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroadcastMutableMessage_TimeTick_Call) RunAndReturn(run func() uint64) *MockBroadcastMutableMessage_TimeTick_Call { + _c.Call.Return(run) + return _c +} + +// TxnContext provides a mock function with no fields +func (_m *MockBroadcastMutableMessage) TxnContext() *message.TxnContext { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for TxnContext") + } + + var r0 *message.TxnContext + if rf, ok := ret.Get(0).(func() *message.TxnContext); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*message.TxnContext) + } + } + + return r0 +} + +// MockBroadcastMutableMessage_TxnContext_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TxnContext' +type MockBroadcastMutableMessage_TxnContext_Call struct { + *mock.Call +} + +// TxnContext is a helper method to define mock.On call +func (_e *MockBroadcastMutableMessage_Expecter) TxnContext() *MockBroadcastMutableMessage_TxnContext_Call { + return &MockBroadcastMutableMessage_TxnContext_Call{Call: _e.mock.On("TxnContext")} +} + +func (_c *MockBroadcastMutableMessage_TxnContext_Call) Run(run func()) *MockBroadcastMutableMessage_TxnContext_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockBroadcastMutableMessage_TxnContext_Call) Return(_a0 *message.TxnContext) *MockBroadcastMutableMessage_TxnContext_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroadcastMutableMessage_TxnContext_Call) RunAndReturn(run func() *message.TxnContext) *MockBroadcastMutableMessage_TxnContext_Call { + _c.Call.Return(run) + return _c +} + +// Version provides a mock function with no fields +func (_m *MockBroadcastMutableMessage) Version() message.Version { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Version") + } + + var r0 message.Version + if rf, ok := ret.Get(0).(func() message.Version); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(message.Version) + } + + return r0 +} + +// MockBroadcastMutableMessage_Version_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Version' +type MockBroadcastMutableMessage_Version_Call struct { + *mock.Call +} + +// Version is a helper method to define mock.On call +func (_e *MockBroadcastMutableMessage_Expecter) Version() *MockBroadcastMutableMessage_Version_Call { + return &MockBroadcastMutableMessage_Version_Call{Call: _e.mock.On("Version")} +} + +func (_c *MockBroadcastMutableMessage_Version_Call) Run(run func()) *MockBroadcastMutableMessage_Version_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockBroadcastMutableMessage_Version_Call) Return(_a0 message.Version) *MockBroadcastMutableMessage_Version_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroadcastMutableMessage_Version_Call) RunAndReturn(run func() message.Version) *MockBroadcastMutableMessage_Version_Call { + _c.Call.Return(run) + return _c +} + +// WithBroadcastID provides a mock function with given fields: broadcastID +func (_m *MockBroadcastMutableMessage) WithBroadcastID(broadcastID uint64) message.BroadcastMutableMessage { + ret := _m.Called(broadcastID) + + if len(ret) == 0 { + panic("no return value specified for WithBroadcastID") + } + + var r0 message.BroadcastMutableMessage + if rf, ok := ret.Get(0).(func(uint64) message.BroadcastMutableMessage); ok { + r0 = rf(broadcastID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(message.BroadcastMutableMessage) + } + } + + return r0 +} + +// MockBroadcastMutableMessage_WithBroadcastID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WithBroadcastID' +type MockBroadcastMutableMessage_WithBroadcastID_Call struct { + *mock.Call +} + +// WithBroadcastID is a helper method to define mock.On call +// - broadcastID uint64 +func (_e *MockBroadcastMutableMessage_Expecter) WithBroadcastID(broadcastID interface{}) *MockBroadcastMutableMessage_WithBroadcastID_Call { + return &MockBroadcastMutableMessage_WithBroadcastID_Call{Call: _e.mock.On("WithBroadcastID", broadcastID)} +} + +func (_c *MockBroadcastMutableMessage_WithBroadcastID_Call) Run(run func(broadcastID uint64)) *MockBroadcastMutableMessage_WithBroadcastID_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(uint64)) + }) + return _c +} + +func (_c *MockBroadcastMutableMessage_WithBroadcastID_Call) Return(_a0 message.BroadcastMutableMessage) *MockBroadcastMutableMessage_WithBroadcastID_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockBroadcastMutableMessage_WithBroadcastID_Call) RunAndReturn(run func(uint64) message.BroadcastMutableMessage) *MockBroadcastMutableMessage_WithBroadcastID_Call { + _c.Call.Return(run) + return _c +} + +// NewMockBroadcastMutableMessage creates a new instance of MockBroadcastMutableMessage. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMockBroadcastMutableMessage(t interface { + mock.TestingT + Cleanup(func()) +}) *MockBroadcastMutableMessage { + mock := &MockBroadcastMutableMessage{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/pkg/streaming/util/message/message.go b/pkg/streaming/util/message/message.go index 18193f32b9..a2bf796d89 100644 --- a/pkg/streaming/util/message/message.go +++ b/pkg/streaming/util/message/message.go @@ -167,6 +167,9 @@ type specializedMutableMessage[H proto.Message, B proto.Message] interface { // !!! Do these will trigger a unmarshal operation, so it should be used with caution. Body() (B, error) + // MustBody return the message body, panic if error occurs. + MustBody() B + // OverwriteHeader overwrites the message header. OverwriteHeader(header H) } diff --git a/pkg/streaming/util/message/message_test.go b/pkg/streaming/util/message/message_test.go index e96491a7ed..fd5c2404fd 100644 --- a/pkg/streaming/util/message/message_test.go +++ b/pkg/streaming/util/message/message_test.go @@ -74,6 +74,8 @@ func TestBroadcast(t *testing.T) { assert.Equal(t, uint64(1), msgs[1].BroadcastHeader().BroadcastID) assert.Len(t, msgs[0].BroadcastHeader().ResourceKeys, 2) assert.ElementsMatch(t, []string{"v1", "v2"}, []string{msgs[0].VChannel(), msgs[1].VChannel()}) + + MustAsMutableCreateCollectionMessageV1(msg) } func TestCiper(t *testing.T) { diff --git a/pkg/streaming/util/message/specialized_message.go b/pkg/streaming/util/message/specialized_message.go index 8065598410..4853311cb6 100644 --- a/pkg/streaming/util/message/specialized_message.go +++ b/pkg/streaming/util/message/specialized_message.go @@ -349,6 +349,15 @@ func (m *specializedMutableMessageImpl[H, B]) Body() (B, error) { return unmarshalProtoB[B](m.Payload()) } +// MustBody returns the message body. +func (m *specializedMutableMessageImpl[H, B]) MustBody() B { + b, err := m.Body() + if err != nil { + panic(fmt.Sprintf("failed to unmarshal specialized body,%s", err.Error())) + } + return b +} + // OverwriteMessageHeader overwrites the message header. func (m *specializedMutableMessageImpl[H, B]) OverwriteHeader(header H) { m.header = header