fix: Pre-check import message to prevent pipeline block indefinitely (#42415)

Pre-check import message to prevent pipeline block indefinitely.

issue: https://github.com/milvus-io/milvus/issues/42414

---------

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
Co-authored-by: chyezh <chyezh@outlook.com>
This commit is contained in:
yihao.dai 2025-06-11 13:40:38 +08:00 committed by GitHub
parent e7c0a6ffbb
commit e6da4a64b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 1897 additions and 254 deletions

View File

@ -32,6 +32,7 @@ packages:
ChannelManager:
SubCluster:
StatsJobManager:
ImportMeta:
github.com/milvus-io/milvus/internal/datacoord/allocator:
interfaces:
Allocator:

View File

@ -22,6 +22,7 @@ import (
"math"
"path"
"sort"
"sync"
"time"
"github.com/cockroachdb/errors"
@ -39,6 +40,8 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/taskcommon"
"github.com/milvus-io/milvus/pkg/v2/util/conc"
"github.com/milvus-io/milvus/pkg/v2/util/hardware"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
@ -549,7 +552,9 @@ func getIndexBuildingProgress(ctx context.Context, jobID int64, importMeta Impor
// 10%: Completed
// TODO: Wrap a function to map status to user status.
// TODO: Save these progress to job instead of recalculating.
func GetJobProgress(ctx context.Context, jobID int64, importMeta ImportMeta, meta *meta, sjm StatsInspector) (int64, internalpb.ImportJobState, int64, int64, string) {
func GetJobProgress(ctx context.Context, jobID int64,
importMeta ImportMeta, meta *meta, sjm StatsInspector,
) (int64, internalpb.ImportJobState, int64, int64, string) {
job := importMeta.GetJob(ctx, jobID)
if job == nil {
return 0, internalpb.ImportJobState_Failed, 0, 0, fmt.Sprintf("import job does not exist, jobID=%d", jobID)
@ -627,7 +632,9 @@ func DropImportTask(task ImportTask, cluster session.Cluster, tm ImportMeta) err
return tm.UpdateTask(context.TODO(), task.GetTaskID(), UpdateNodeID(NullNodeID))
}
func ListBinlogsAndGroupBySegment(ctx context.Context, cm storage.ChunkManager, importFile *internalpb.ImportFile) ([]*internalpb.ImportFile, error) {
func ListBinlogsAndGroupBySegment(ctx context.Context,
cm storage.ChunkManager, importFile *internalpb.ImportFile,
) ([]*internalpb.ImportFile, error) {
if len(importFile.GetPaths()) == 0 {
return nil, merr.WrapErrImportFailed("no insert binlogs to import")
}
@ -668,3 +675,73 @@ func ListBinlogsAndGroupBySegment(ctx context.Context, cm storage.ChunkManager,
}
return segmentImportFiles, nil
}
// ValidateBinlogImportRequest validates the binlog import request.
func ValidateBinlogImportRequest(ctx context.Context, cm storage.ChunkManager,
reqFiles []*msgpb.ImportFile, options []*commonpb.KeyValuePair,
) error {
files := lo.Map(reqFiles, func(file *msgpb.ImportFile, _ int) *internalpb.ImportFile {
return &internalpb.ImportFile{Id: file.GetId(), Paths: file.GetPaths()}
})
_, err := ListBinlogImportRequestFiles(ctx, cm, files, options)
return err
}
// ListBinlogImportRequestFiles lists the binlog files from the request.
// TODO: dyh, remove listing binlog after backup-restore derectly passed the segments paths.
func ListBinlogImportRequestFiles(ctx context.Context, cm storage.ChunkManager,
reqFiles []*internalpb.ImportFile, options []*commonpb.KeyValuePair,
) ([]*internalpb.ImportFile, error) {
isBackup := importutilv2.IsBackup(options)
if !isBackup {
return reqFiles, nil
}
resFiles := make([]*internalpb.ImportFile, 0)
pool := conc.NewPool[struct{}](hardware.GetCPUNum() * 2)
defer pool.Release()
futures := make([]*conc.Future[struct{}], 0, len(reqFiles))
mu := &sync.Mutex{}
for _, importFile := range reqFiles {
importFile := importFile
futures = append(futures, pool.Submit(func() (struct{}, error) {
segmentPrefixes, err := ListBinlogsAndGroupBySegment(ctx, cm, importFile)
if err != nil {
return struct{}{}, err
}
mu.Lock()
defer mu.Unlock()
resFiles = append(resFiles, segmentPrefixes...)
return struct{}{}, nil
}))
}
err := conc.AwaitAll(futures...)
if err != nil {
return nil, merr.WrapErrImportFailed(fmt.Sprintf("list binlogs failed, err=%s", err))
}
resFiles = lo.Filter(resFiles, func(file *internalpb.ImportFile, _ int) bool {
return len(file.GetPaths()) > 0
})
if len(resFiles) == 0 {
return nil, merr.WrapErrImportFailed(fmt.Sprintf("no binlog to import, input=%s", reqFiles))
}
if len(resFiles) > paramtable.Get().DataCoordCfg.MaxFilesPerImportReq.GetAsInt() {
return nil, merr.WrapErrImportFailed(fmt.Sprintf("The max number of import files should not exceed %d, but got %d",
paramtable.Get().DataCoordCfg.MaxFilesPerImportReq.GetAsInt(), len(resFiles)))
}
log.Info("list binlogs prefixes for import done", zap.Int("num", len(resFiles)), zap.Any("binlog_prefixes", resFiles))
return resFiles, nil
}
// ValidateMaxImportJobExceed checks if the number of import jobs exceeds the limit.
func ValidateMaxImportJobExceed(ctx context.Context, importMeta ImportMeta) error {
maxNum := paramtable.Get().DataCoordCfg.MaxImportJobNum.GetAsInt()
executingNum := importMeta.CountJobBy(ctx, WithoutJobStates(internalpb.ImportJobState_Completed, internalpb.ImportJobState_Failed))
if executingNum >= maxNum {
return merr.WrapErrImportFailed(
fmt.Sprintf("The number of jobs has reached the limit, please try again later. " +
"If your request is set to only import a single file, " +
"please consider importing multiple files in one request for better efficiency."))
}
return nil
}

View File

@ -21,6 +21,7 @@ import (
"fmt"
"math/rand"
"path"
"strings"
"testing"
"time"
@ -31,6 +32,7 @@ import (
"go.uber.org/atomic"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/datacoord/allocator"
"github.com/milvus-io/milvus/internal/datacoord/broker"
@ -885,3 +887,233 @@ func TestImportTask_MarshalJSON(t *testing.T) {
assert.Equal(t, task.GetCreatedTime(), importTask.CreatedTime)
assert.Equal(t, task.GetCompleteTime(), importTask.CompleteTime)
}
// TestImportUtil_ValidateBinlogImportRequest tests the validation of binlog import request
func TestImportUtil_ValidateBinlogImportRequest(t *testing.T) {
ctx := context.Background()
mockCM := mocks2.NewChunkManager(t)
t.Run("empty files", func(t *testing.T) {
options := []*commonpb.KeyValuePair{
{
Key: importutilv2.BackupFlag,
Value: "true",
},
}
err := ValidateBinlogImportRequest(ctx, mockCM, nil, options)
assert.Error(t, err)
})
t.Run("valid files - not backup", func(t *testing.T) {
files := []*msgpb.ImportFile{
{
Id: 1,
Paths: []string{"path1"},
},
}
err := ValidateBinlogImportRequest(ctx, mockCM, files, nil)
assert.NoError(t, err)
})
t.Run("invalid files - too many paths", func(t *testing.T) {
files := []*msgpb.ImportFile{
{
Id: 1,
Paths: []string{"path1", "path2", "path3"},
},
}
options := []*commonpb.KeyValuePair{
{
Key: importutilv2.BackupFlag,
Value: "true",
},
}
err := ValidateBinlogImportRequest(ctx, mockCM, files, options)
assert.Error(t, err)
assert.Contains(t, err.Error(), "too many input paths")
})
}
// TestImportUtil_ListBinlogImportRequestFiles tests listing binlog files from import request
func TestImportUtil_ListBinlogImportRequestFiles(t *testing.T) {
ctx := context.Background()
t.Run("empty files", func(t *testing.T) {
options := []*commonpb.KeyValuePair{
{
Key: importutilv2.BackupFlag,
Value: "true",
},
}
files, err := ListBinlogImportRequestFiles(ctx, nil, nil, options)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no binlog to import")
assert.Nil(t, files)
})
t.Run("not backup files", func(t *testing.T) {
reqFiles := []*internalpb.ImportFile{
{
Paths: []string{"path1"},
},
}
files, err := ListBinlogImportRequestFiles(ctx, nil, reqFiles, nil)
assert.NoError(t, err)
assert.Equal(t, reqFiles, files)
})
t.Run("backup files - list error", func(t *testing.T) {
reqFiles := []*internalpb.ImportFile{
{
Paths: []string{"path1"},
},
}
options := []*commonpb.KeyValuePair{
{
Key: importutilv2.BackupFlag,
Value: "true",
},
}
mockCM := mocks2.NewChunkManager(t)
mockCM.EXPECT().WalkWithPrefix(mock.Anything, mock.Anything, mock.Anything, mock.Anything).
Return(errors.New("mock error"))
files, err := ListBinlogImportRequestFiles(ctx, mockCM, reqFiles, options)
assert.Error(t, err)
assert.Contains(t, err.Error(), "list binlogs failed")
assert.Nil(t, files)
})
t.Run("backup files - success", func(t *testing.T) {
reqFiles := []*internalpb.ImportFile{
{
Paths: []string{"path1"},
},
}
options := []*commonpb.KeyValuePair{
{
Key: importutilv2.BackupFlag,
Value: "true",
},
}
mockCM := mocks2.NewChunkManager(t)
mockCM.EXPECT().WalkWithPrefix(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, prefix string, recursive bool, walkFunc storage.ChunkObjectWalkFunc) error {
walkFunc(&storage.ChunkObjectInfo{
FilePath: "path1",
})
return nil
})
files, err := ListBinlogImportRequestFiles(ctx, mockCM, reqFiles, options)
assert.NoError(t, err)
assert.Equal(t, 1, len(files))
assert.Equal(t, "path1", files[0].GetPaths()[0])
})
t.Run("backup files - empty result", func(t *testing.T) {
reqFiles := []*internalpb.ImportFile{
{
Paths: []string{"path1"},
},
}
options := []*commonpb.KeyValuePair{
{
Key: importutilv2.BackupFlag,
Value: "true",
},
}
mockCM := mocks2.NewChunkManager(t)
mockCM.EXPECT().WalkWithPrefix(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, prefix string, recursive bool, walkFunc storage.ChunkObjectWalkFunc) error {
return nil
})
files, err := ListBinlogImportRequestFiles(ctx, mockCM, reqFiles, options)
assert.Error(t, err)
assert.Contains(t, err.Error(), "no binlog to import")
assert.Nil(t, files)
})
t.Run("backup files - too many files", func(t *testing.T) {
maxFiles := paramtable.Get().DataCoordCfg.MaxFilesPerImportReq.GetAsInt()
reqFiles := make([]*internalpb.ImportFile, maxFiles+1)
for i := 0; i < maxFiles+1; i++ {
reqFiles[i] = &internalpb.ImportFile{
Paths: []string{fmt.Sprintf("path%d", i)},
}
}
options := []*commonpb.KeyValuePair{
{
Key: importutilv2.BackupFlag,
Value: "true",
},
}
mockCM := mocks2.NewChunkManager(t)
mockCM.EXPECT().WalkWithPrefix(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, prefix string, recursive bool, walkFunc storage.ChunkObjectWalkFunc) error {
for i := 0; i < maxFiles+1; i++ {
walkFunc(&storage.ChunkObjectInfo{
FilePath: fmt.Sprintf("path%d", i),
})
}
return nil
})
files, err := ListBinlogImportRequestFiles(ctx, mockCM, reqFiles, options)
assert.Error(t, err)
assert.Contains(t, err.Error(), fmt.Sprintf("The max number of import files should not exceed %d", maxFiles))
assert.Nil(t, files)
})
t.Run("backup files - multiple files with delta", func(t *testing.T) {
reqFiles := []*internalpb.ImportFile{
{
Paths: []string{"insert/path1", "delta/path1"},
},
}
options := []*commonpb.KeyValuePair{
{
Key: importutilv2.BackupFlag,
Value: "true",
},
}
mockCM := mocks2.NewChunkManager(t)
mockCM.EXPECT().WalkWithPrefix(mock.Anything, mock.Anything, mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, prefix string, recursive bool, walkFunc storage.ChunkObjectWalkFunc) error {
if strings.Contains(prefix, "insert") {
walkFunc(&storage.ChunkObjectInfo{
FilePath: "insert/path1",
})
} else if strings.Contains(prefix, "delta") {
walkFunc(&storage.ChunkObjectInfo{
FilePath: "delta/path1",
})
}
return nil
}).Times(2)
files, err := ListBinlogImportRequestFiles(ctx, mockCM, reqFiles, options)
assert.NoError(t, err)
assert.Equal(t, 1, len(files))
assert.Equal(t, 2, len(files[0].GetPaths()))
assert.Equal(t, "insert/path1", files[0].GetPaths()[0])
assert.Equal(t, "delta/path1", files[0].GetPaths()[1])
})
}
// TestImportUtil_ValidateMaxImportJobExceed tests validation of maximum import jobs
func TestImportUtil_ValidateMaxImportJobExceed(t *testing.T) {
ctx := context.Background()
t.Run("job count within limit", func(t *testing.T) {
mockImportMeta := NewMockImportMeta(t)
mockImportMeta.EXPECT().CountJobBy(mock.Anything, mock.Anything).Return(1)
err := ValidateMaxImportJobExceed(ctx, mockImportMeta)
assert.NoError(t, err)
})
t.Run("job count exceeds limit", func(t *testing.T) {
mockImportMeta := NewMockImportMeta(t)
mockImportMeta.EXPECT().CountJobBy(mock.Anything, mock.Anything).
Return(paramtable.Get().DataCoordCfg.MaxImportJobNum.GetAsInt() + 1)
err := ValidateMaxImportJobExceed(ctx, mockImportMeta)
assert.Error(t, err)
assert.Contains(t, err.Error(), "The number of jobs has reached the limit")
})
}

View File

@ -0,0 +1,679 @@
// Code generated by mockery v2.53.3. DO NOT EDIT.
package datacoord
import (
context "context"
mock "github.com/stretchr/testify/mock"
)
// MockImportMeta is an autogenerated mock type for the ImportMeta type
type MockImportMeta struct {
mock.Mock
}
type MockImportMeta_Expecter struct {
mock *mock.Mock
}
func (_m *MockImportMeta) EXPECT() *MockImportMeta_Expecter {
return &MockImportMeta_Expecter{mock: &_m.Mock}
}
// AddJob provides a mock function with given fields: ctx, job
func (_m *MockImportMeta) AddJob(ctx context.Context, job ImportJob) error {
ret := _m.Called(ctx, job)
if len(ret) == 0 {
panic("no return value specified for AddJob")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, ImportJob) error); ok {
r0 = rf(ctx, job)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockImportMeta_AddJob_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddJob'
type MockImportMeta_AddJob_Call struct {
*mock.Call
}
// AddJob is a helper method to define mock.On call
// - ctx context.Context
// - job ImportJob
func (_e *MockImportMeta_Expecter) AddJob(ctx interface{}, job interface{}) *MockImportMeta_AddJob_Call {
return &MockImportMeta_AddJob_Call{Call: _e.mock.On("AddJob", ctx, job)}
}
func (_c *MockImportMeta_AddJob_Call) Run(run func(ctx context.Context, job ImportJob)) *MockImportMeta_AddJob_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(ImportJob))
})
return _c
}
func (_c *MockImportMeta_AddJob_Call) Return(_a0 error) *MockImportMeta_AddJob_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockImportMeta_AddJob_Call) RunAndReturn(run func(context.Context, ImportJob) error) *MockImportMeta_AddJob_Call {
_c.Call.Return(run)
return _c
}
// AddTask provides a mock function with given fields: ctx, task
func (_m *MockImportMeta) AddTask(ctx context.Context, task ImportTask) error {
ret := _m.Called(ctx, task)
if len(ret) == 0 {
panic("no return value specified for AddTask")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, ImportTask) error); ok {
r0 = rf(ctx, task)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockImportMeta_AddTask_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddTask'
type MockImportMeta_AddTask_Call struct {
*mock.Call
}
// AddTask is a helper method to define mock.On call
// - ctx context.Context
// - task ImportTask
func (_e *MockImportMeta_Expecter) AddTask(ctx interface{}, task interface{}) *MockImportMeta_AddTask_Call {
return &MockImportMeta_AddTask_Call{Call: _e.mock.On("AddTask", ctx, task)}
}
func (_c *MockImportMeta_AddTask_Call) Run(run func(ctx context.Context, task ImportTask)) *MockImportMeta_AddTask_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(ImportTask))
})
return _c
}
func (_c *MockImportMeta_AddTask_Call) Return(_a0 error) *MockImportMeta_AddTask_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockImportMeta_AddTask_Call) RunAndReturn(run func(context.Context, ImportTask) error) *MockImportMeta_AddTask_Call {
_c.Call.Return(run)
return _c
}
// CountJobBy provides a mock function with given fields: ctx, filters
func (_m *MockImportMeta) CountJobBy(ctx context.Context, filters ...ImportJobFilter) int {
_va := make([]interface{}, len(filters))
for _i := range filters {
_va[_i] = filters[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for CountJobBy")
}
var r0 int
if rf, ok := ret.Get(0).(func(context.Context, ...ImportJobFilter) int); ok {
r0 = rf(ctx, filters...)
} else {
r0 = ret.Get(0).(int)
}
return r0
}
// MockImportMeta_CountJobBy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CountJobBy'
type MockImportMeta_CountJobBy_Call struct {
*mock.Call
}
// CountJobBy is a helper method to define mock.On call
// - ctx context.Context
// - filters ...ImportJobFilter
func (_e *MockImportMeta_Expecter) CountJobBy(ctx interface{}, filters ...interface{}) *MockImportMeta_CountJobBy_Call {
return &MockImportMeta_CountJobBy_Call{Call: _e.mock.On("CountJobBy",
append([]interface{}{ctx}, filters...)...)}
}
func (_c *MockImportMeta_CountJobBy_Call) Run(run func(ctx context.Context, filters ...ImportJobFilter)) *MockImportMeta_CountJobBy_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]ImportJobFilter, len(args)-1)
for i, a := range args[1:] {
if a != nil {
variadicArgs[i] = a.(ImportJobFilter)
}
}
run(args[0].(context.Context), variadicArgs...)
})
return _c
}
func (_c *MockImportMeta_CountJobBy_Call) Return(_a0 int) *MockImportMeta_CountJobBy_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockImportMeta_CountJobBy_Call) RunAndReturn(run func(context.Context, ...ImportJobFilter) int) *MockImportMeta_CountJobBy_Call {
_c.Call.Return(run)
return _c
}
// GetJob provides a mock function with given fields: ctx, jobID
func (_m *MockImportMeta) GetJob(ctx context.Context, jobID int64) ImportJob {
ret := _m.Called(ctx, jobID)
if len(ret) == 0 {
panic("no return value specified for GetJob")
}
var r0 ImportJob
if rf, ok := ret.Get(0).(func(context.Context, int64) ImportJob); ok {
r0 = rf(ctx, jobID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(ImportJob)
}
}
return r0
}
// MockImportMeta_GetJob_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetJob'
type MockImportMeta_GetJob_Call struct {
*mock.Call
}
// GetJob is a helper method to define mock.On call
// - ctx context.Context
// - jobID int64
func (_e *MockImportMeta_Expecter) GetJob(ctx interface{}, jobID interface{}) *MockImportMeta_GetJob_Call {
return &MockImportMeta_GetJob_Call{Call: _e.mock.On("GetJob", ctx, jobID)}
}
func (_c *MockImportMeta_GetJob_Call) Run(run func(ctx context.Context, jobID int64)) *MockImportMeta_GetJob_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
func (_c *MockImportMeta_GetJob_Call) Return(_a0 ImportJob) *MockImportMeta_GetJob_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockImportMeta_GetJob_Call) RunAndReturn(run func(context.Context, int64) ImportJob) *MockImportMeta_GetJob_Call {
_c.Call.Return(run)
return _c
}
// GetJobBy provides a mock function with given fields: ctx, filters
func (_m *MockImportMeta) GetJobBy(ctx context.Context, filters ...ImportJobFilter) []ImportJob {
_va := make([]interface{}, len(filters))
for _i := range filters {
_va[_i] = filters[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for GetJobBy")
}
var r0 []ImportJob
if rf, ok := ret.Get(0).(func(context.Context, ...ImportJobFilter) []ImportJob); ok {
r0 = rf(ctx, filters...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]ImportJob)
}
}
return r0
}
// MockImportMeta_GetJobBy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetJobBy'
type MockImportMeta_GetJobBy_Call struct {
*mock.Call
}
// GetJobBy is a helper method to define mock.On call
// - ctx context.Context
// - filters ...ImportJobFilter
func (_e *MockImportMeta_Expecter) GetJobBy(ctx interface{}, filters ...interface{}) *MockImportMeta_GetJobBy_Call {
return &MockImportMeta_GetJobBy_Call{Call: _e.mock.On("GetJobBy",
append([]interface{}{ctx}, filters...)...)}
}
func (_c *MockImportMeta_GetJobBy_Call) Run(run func(ctx context.Context, filters ...ImportJobFilter)) *MockImportMeta_GetJobBy_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]ImportJobFilter, len(args)-1)
for i, a := range args[1:] {
if a != nil {
variadicArgs[i] = a.(ImportJobFilter)
}
}
run(args[0].(context.Context), variadicArgs...)
})
return _c
}
func (_c *MockImportMeta_GetJobBy_Call) Return(_a0 []ImportJob) *MockImportMeta_GetJobBy_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockImportMeta_GetJobBy_Call) RunAndReturn(run func(context.Context, ...ImportJobFilter) []ImportJob) *MockImportMeta_GetJobBy_Call {
_c.Call.Return(run)
return _c
}
// GetTask provides a mock function with given fields: ctx, taskID
func (_m *MockImportMeta) GetTask(ctx context.Context, taskID int64) ImportTask {
ret := _m.Called(ctx, taskID)
if len(ret) == 0 {
panic("no return value specified for GetTask")
}
var r0 ImportTask
if rf, ok := ret.Get(0).(func(context.Context, int64) ImportTask); ok {
r0 = rf(ctx, taskID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(ImportTask)
}
}
return r0
}
// MockImportMeta_GetTask_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetTask'
type MockImportMeta_GetTask_Call struct {
*mock.Call
}
// GetTask is a helper method to define mock.On call
// - ctx context.Context
// - taskID int64
func (_e *MockImportMeta_Expecter) GetTask(ctx interface{}, taskID interface{}) *MockImportMeta_GetTask_Call {
return &MockImportMeta_GetTask_Call{Call: _e.mock.On("GetTask", ctx, taskID)}
}
func (_c *MockImportMeta_GetTask_Call) Run(run func(ctx context.Context, taskID int64)) *MockImportMeta_GetTask_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
func (_c *MockImportMeta_GetTask_Call) Return(_a0 ImportTask) *MockImportMeta_GetTask_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockImportMeta_GetTask_Call) RunAndReturn(run func(context.Context, int64) ImportTask) *MockImportMeta_GetTask_Call {
_c.Call.Return(run)
return _c
}
// GetTaskBy provides a mock function with given fields: ctx, filters
func (_m *MockImportMeta) GetTaskBy(ctx context.Context, filters ...ImportTaskFilter) []ImportTask {
_va := make([]interface{}, len(filters))
for _i := range filters {
_va[_i] = filters[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for GetTaskBy")
}
var r0 []ImportTask
if rf, ok := ret.Get(0).(func(context.Context, ...ImportTaskFilter) []ImportTask); ok {
r0 = rf(ctx, filters...)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]ImportTask)
}
}
return r0
}
// MockImportMeta_GetTaskBy_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetTaskBy'
type MockImportMeta_GetTaskBy_Call struct {
*mock.Call
}
// GetTaskBy is a helper method to define mock.On call
// - ctx context.Context
// - filters ...ImportTaskFilter
func (_e *MockImportMeta_Expecter) GetTaskBy(ctx interface{}, filters ...interface{}) *MockImportMeta_GetTaskBy_Call {
return &MockImportMeta_GetTaskBy_Call{Call: _e.mock.On("GetTaskBy",
append([]interface{}{ctx}, filters...)...)}
}
func (_c *MockImportMeta_GetTaskBy_Call) Run(run func(ctx context.Context, filters ...ImportTaskFilter)) *MockImportMeta_GetTaskBy_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]ImportTaskFilter, len(args)-1)
for i, a := range args[1:] {
if a != nil {
variadicArgs[i] = a.(ImportTaskFilter)
}
}
run(args[0].(context.Context), variadicArgs...)
})
return _c
}
func (_c *MockImportMeta_GetTaskBy_Call) Return(_a0 []ImportTask) *MockImportMeta_GetTaskBy_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockImportMeta_GetTaskBy_Call) RunAndReturn(run func(context.Context, ...ImportTaskFilter) []ImportTask) *MockImportMeta_GetTaskBy_Call {
_c.Call.Return(run)
return _c
}
// RemoveJob provides a mock function with given fields: ctx, jobID
func (_m *MockImportMeta) RemoveJob(ctx context.Context, jobID int64) error {
ret := _m.Called(ctx, jobID)
if len(ret) == 0 {
panic("no return value specified for RemoveJob")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok {
r0 = rf(ctx, jobID)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockImportMeta_RemoveJob_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveJob'
type MockImportMeta_RemoveJob_Call struct {
*mock.Call
}
// RemoveJob is a helper method to define mock.On call
// - ctx context.Context
// - jobID int64
func (_e *MockImportMeta_Expecter) RemoveJob(ctx interface{}, jobID interface{}) *MockImportMeta_RemoveJob_Call {
return &MockImportMeta_RemoveJob_Call{Call: _e.mock.On("RemoveJob", ctx, jobID)}
}
func (_c *MockImportMeta_RemoveJob_Call) Run(run func(ctx context.Context, jobID int64)) *MockImportMeta_RemoveJob_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
func (_c *MockImportMeta_RemoveJob_Call) Return(_a0 error) *MockImportMeta_RemoveJob_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockImportMeta_RemoveJob_Call) RunAndReturn(run func(context.Context, int64) error) *MockImportMeta_RemoveJob_Call {
_c.Call.Return(run)
return _c
}
// RemoveTask provides a mock function with given fields: ctx, taskID
func (_m *MockImportMeta) RemoveTask(ctx context.Context, taskID int64) error {
ret := _m.Called(ctx, taskID)
if len(ret) == 0 {
panic("no return value specified for RemoveTask")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok {
r0 = rf(ctx, taskID)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockImportMeta_RemoveTask_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveTask'
type MockImportMeta_RemoveTask_Call struct {
*mock.Call
}
// RemoveTask is a helper method to define mock.On call
// - ctx context.Context
// - taskID int64
func (_e *MockImportMeta_Expecter) RemoveTask(ctx interface{}, taskID interface{}) *MockImportMeta_RemoveTask_Call {
return &MockImportMeta_RemoveTask_Call{Call: _e.mock.On("RemoveTask", ctx, taskID)}
}
func (_c *MockImportMeta_RemoveTask_Call) Run(run func(ctx context.Context, taskID int64)) *MockImportMeta_RemoveTask_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(int64))
})
return _c
}
func (_c *MockImportMeta_RemoveTask_Call) Return(_a0 error) *MockImportMeta_RemoveTask_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockImportMeta_RemoveTask_Call) RunAndReturn(run func(context.Context, int64) error) *MockImportMeta_RemoveTask_Call {
_c.Call.Return(run)
return _c
}
// TaskStatsJSON provides a mock function with given fields: ctx
func (_m *MockImportMeta) TaskStatsJSON(ctx context.Context) string {
ret := _m.Called(ctx)
if len(ret) == 0 {
panic("no return value specified for TaskStatsJSON")
}
var r0 string
if rf, ok := ret.Get(0).(func(context.Context) string); ok {
r0 = rf(ctx)
} else {
r0 = ret.Get(0).(string)
}
return r0
}
// MockImportMeta_TaskStatsJSON_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TaskStatsJSON'
type MockImportMeta_TaskStatsJSON_Call struct {
*mock.Call
}
// TaskStatsJSON is a helper method to define mock.On call
// - ctx context.Context
func (_e *MockImportMeta_Expecter) TaskStatsJSON(ctx interface{}) *MockImportMeta_TaskStatsJSON_Call {
return &MockImportMeta_TaskStatsJSON_Call{Call: _e.mock.On("TaskStatsJSON", ctx)}
}
func (_c *MockImportMeta_TaskStatsJSON_Call) Run(run func(ctx context.Context)) *MockImportMeta_TaskStatsJSON_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context))
})
return _c
}
func (_c *MockImportMeta_TaskStatsJSON_Call) Return(_a0 string) *MockImportMeta_TaskStatsJSON_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockImportMeta_TaskStatsJSON_Call) RunAndReturn(run func(context.Context) string) *MockImportMeta_TaskStatsJSON_Call {
_c.Call.Return(run)
return _c
}
// UpdateJob provides a mock function with given fields: ctx, jobID, actions
func (_m *MockImportMeta) UpdateJob(ctx context.Context, jobID int64, actions ...UpdateJobAction) error {
_va := make([]interface{}, len(actions))
for _i := range actions {
_va[_i] = actions[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, jobID)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for UpdateJob")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int64, ...UpdateJobAction) error); ok {
r0 = rf(ctx, jobID, actions...)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockImportMeta_UpdateJob_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateJob'
type MockImportMeta_UpdateJob_Call struct {
*mock.Call
}
// UpdateJob is a helper method to define mock.On call
// - ctx context.Context
// - jobID int64
// - actions ...UpdateJobAction
func (_e *MockImportMeta_Expecter) UpdateJob(ctx interface{}, jobID interface{}, actions ...interface{}) *MockImportMeta_UpdateJob_Call {
return &MockImportMeta_UpdateJob_Call{Call: _e.mock.On("UpdateJob",
append([]interface{}{ctx, jobID}, actions...)...)}
}
func (_c *MockImportMeta_UpdateJob_Call) Run(run func(ctx context.Context, jobID int64, actions ...UpdateJobAction)) *MockImportMeta_UpdateJob_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]UpdateJobAction, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(UpdateJobAction)
}
}
run(args[0].(context.Context), args[1].(int64), variadicArgs...)
})
return _c
}
func (_c *MockImportMeta_UpdateJob_Call) Return(_a0 error) *MockImportMeta_UpdateJob_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockImportMeta_UpdateJob_Call) RunAndReturn(run func(context.Context, int64, ...UpdateJobAction) error) *MockImportMeta_UpdateJob_Call {
_c.Call.Return(run)
return _c
}
// UpdateTask provides a mock function with given fields: ctx, taskID, actions
func (_m *MockImportMeta) UpdateTask(ctx context.Context, taskID int64, actions ...UpdateAction) error {
_va := make([]interface{}, len(actions))
for _i := range actions {
_va[_i] = actions[_i]
}
var _ca []interface{}
_ca = append(_ca, ctx, taskID)
_ca = append(_ca, _va...)
ret := _m.Called(_ca...)
if len(ret) == 0 {
panic("no return value specified for UpdateTask")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int64, ...UpdateAction) error); ok {
r0 = rf(ctx, taskID, actions...)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockImportMeta_UpdateTask_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateTask'
type MockImportMeta_UpdateTask_Call struct {
*mock.Call
}
// UpdateTask is a helper method to define mock.On call
// - ctx context.Context
// - taskID int64
// - actions ...UpdateAction
func (_e *MockImportMeta_Expecter) UpdateTask(ctx interface{}, taskID interface{}, actions ...interface{}) *MockImportMeta_UpdateTask_Call {
return &MockImportMeta_UpdateTask_Call{Call: _e.mock.On("UpdateTask",
append([]interface{}{ctx, taskID}, actions...)...)}
}
func (_c *MockImportMeta_UpdateTask_Call) Run(run func(ctx context.Context, taskID int64, actions ...UpdateAction)) *MockImportMeta_UpdateTask_Call {
_c.Call.Run(func(args mock.Arguments) {
variadicArgs := make([]UpdateAction, len(args)-2)
for i, a := range args[2:] {
if a != nil {
variadicArgs[i] = a.(UpdateAction)
}
}
run(args[0].(context.Context), args[1].(int64), variadicArgs...)
})
return _c
}
func (_c *MockImportMeta_UpdateTask_Call) Return(_a0 error) *MockImportMeta_UpdateTask_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockImportMeta_UpdateTask_Call) RunAndReturn(run func(context.Context, int64, ...UpdateAction) error) *MockImportMeta_UpdateTask_Call {
_c.Call.Return(run)
return _c
}
// NewMockImportMeta creates a new instance of MockImportMeta. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockImportMeta(t interface {
mock.TestingT
Cleanup(func())
}) *MockImportMeta {
mock := &MockImportMeta{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -36,6 +36,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
globalIDAllocator "github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/datacoord/allocator"
"github.com/milvus-io/milvus/internal/datacoord/broker"
@ -49,15 +50,18 @@ import (
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/importutilv2"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/streamingutil"
"github.com/milvus-io/milvus/pkg/v2/kv"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/expr"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/logutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
@ -331,17 +335,71 @@ func (s *Server) initDataCoord() error {
log.Info("init datacoord done", zap.Int64("nodeID", paramtable.GetNodeID()), zap.String("Address", s.address))
s.initMessageAckCallback()
s.initMessageCallback()
return nil
}
// initMessageAckCallback initializes the message ack callback.
// initMessageCallback initializes the message callback.
// TODO: we should build a ddl framework to handle the message ack callback for ddl messages
func (s *Server) initMessageAckCallback() {
func (s *Server) initMessageCallback() {
registry.RegisterMessageAckCallback(message.MessageTypeDropPartition, func(ctx context.Context, msg message.MutableMessage) error {
dropPartitionMsg := message.MustAsMutableDropPartitionMessageV1(msg)
return s.NotifyDropPartition(ctx, msg.VChannel(), []int64{dropPartitionMsg.Header().PartitionId})
})
registry.RegisterMessageAckCallback(message.MessageTypeImport, func(ctx context.Context, msg message.MutableMessage) error {
importMsg := message.MustAsMutableImportMessageV1(msg)
body := importMsg.MustBody()
importResp, err := s.ImportV2(ctx, &internalpb.ImportRequestInternal{
CollectionID: body.GetCollectionID(),
CollectionName: body.GetCollectionName(),
PartitionIDs: body.GetPartitionIDs(),
ChannelNames: []string{msg.VChannel()},
Schema: body.GetSchema(),
Files: lo.Map(body.GetFiles(), func(file *msgpb.ImportFile, _ int) *internalpb.ImportFile {
return &internalpb.ImportFile{
Id: file.GetId(),
Paths: file.GetPaths(),
}
}),
Options: funcutil.Map2KeyValuePair(body.GetOptions()),
DataTimestamp: body.GetBase().GetTimestamp(),
JobID: body.GetJobID(),
})
err = merr.CheckRPCCall(importResp, err)
if errors.Is(err, merr.ErrCollectionNotFound) {
log.Ctx(ctx).Warn("import message failed because of collection not found, skip it", zap.String("job_id", importResp.GetJobID()), zap.Error(err))
return nil
}
if err != nil {
log.Ctx(ctx).Warn("import message failed", zap.String("job_id", importResp.GetJobID()), zap.Error(err))
return err
}
log.Ctx(ctx).Info("import message handled", zap.String("job_id", importResp.GetJobID()))
return nil
})
registry.RegisterMessageCheckCallback(message.MessageTypeImport, func(ctx context.Context, msg message.BroadcastMutableMessage) error {
importMsg := message.MustAsMutableImportMessageV1(msg)
b, err := importMsg.Body()
if err != nil {
return err
}
options := funcutil.Map2KeyValuePair(b.GetOptions())
_, err = importutilv2.GetTimeoutTs(options)
if err != nil {
return err
}
err = ValidateBinlogImportRequest(ctx, s.meta.chunkManager, b.GetFiles(), options)
if err != nil {
return err
}
err = ValidateMaxImportJobExceed(ctx, s.importMeta)
if err != nil {
return err
}
return nil
})
}
// Start initialize `Server` members and start loops, follow steps are taken:

View File

@ -46,8 +46,10 @@ import (
"github.com/milvus-io/milvus/internal/datacoord/broker"
"github.com/milvus-io/milvus/internal/datacoord/session"
etcdkv "github.com/milvus-io/milvus/internal/kv/etcd"
mocks2 "github.com/milvus-io/milvus/internal/metastore/mocks"
"github.com/milvus-io/milvus/internal/metastore/model"
"github.com/milvus-io/milvus/internal/mocks"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/sessionutil"
@ -57,6 +59,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/proto/indexpb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/workerpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/etcd"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
@ -2698,3 +2701,71 @@ func TestUpdateAutoBalanceConfigLoop(t *testing.T) {
wg.Wait()
})
}
func TestServer_InitMessageCallback(t *testing.T) {
ctx := context.Background()
mockCatalog := mocks2.NewDataCoordCatalog(t)
mockChunkManager := mocks.NewChunkManager(t)
mockManager := NewMockManager(t)
server := &Server{
ctx: ctx,
meta: &meta{
catalog: mockCatalog,
chunkManager: mockChunkManager,
segments: NewSegmentsInfo(),
},
importMeta: &importMeta{},
segmentManager: mockManager,
}
server.stateCode.Store(commonpb.StateCode_Abnormal)
// Test initMessageCallback
server.initMessageCallback()
// Test DropPartition message callback
dropPartitionMsg, err := message.NewDropPartitionMessageBuilderV1().
WithVChannel("test_channel").
WithHeader(&message.DropPartitionMessageHeader{
CollectionId: 1,
PartitionId: 1,
}).
WithBody(&msgpb.DropPartitionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_DropPartition,
},
}).
BuildMutable()
assert.NoError(t, err)
err = registry.CallMessageAckCallback(ctx, dropPartitionMsg)
assert.Error(t, err) // server not healthy
// Test Import message check callback
resourceKey := message.NewImportJobIDResourceKey(1)
msg, err := message.NewImportMessageBuilderV1().
WithHeader(&message.ImportMessageHeader{}).
WithBody(&msgpb.ImportMsg{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Import,
},
}).
WithBroadcast([]string{"ch-0"}, resourceKey).
BuildBroadcast()
err = registry.CallMessageCheckCallback(ctx, msg)
assert.NoError(t, err)
// Test Import message ack callback
importMsg, err := message.NewImportMessageBuilderV1().
WithVChannel("test_channel").
WithHeader(&message.ImportMessageHeader{}).
WithBody(&msgpb.ImportMsg{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Import,
},
}).
BuildMutable()
assert.NoError(t, err)
err = registry.CallMessageAckCallback(ctx, importMsg)
assert.Error(t, err) // server not healthy
}

View File

@ -21,7 +21,6 @@ import (
"fmt"
"math"
"strconv"
"sync"
"time"
"github.com/cockroachdb/errors"
@ -44,9 +43,7 @@ import (
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/util/conc"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/hardware"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metricsinfo"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
@ -1796,59 +1793,12 @@ func (s *Server) ImportV2(ctx context.Context, in *internalpb.ImportRequestInter
files := in.GetFiles()
isBackup := importutilv2.IsBackup(in.GetOptions())
if isBackup {
files = make([]*internalpb.ImportFile, 0)
pool := conc.NewPool[struct{}](hardware.GetCPUNum() * 2)
defer pool.Release()
futures := make([]*conc.Future[struct{}], 0, len(in.GetFiles()))
mu := &sync.Mutex{}
for _, importFile := range in.GetFiles() {
importFile := importFile
futures = append(futures, pool.Submit(func() (struct{}, error) {
segmentPrefixes, err := ListBinlogsAndGroupBySegment(ctx, s.meta.chunkManager, importFile)
files, err = ListBinlogImportRequestFiles(ctx, s.meta.chunkManager, files, in.GetOptions())
if err != nil {
return struct{}{}, err
}
mu.Lock()
defer mu.Unlock()
files = append(files, segmentPrefixes...)
return struct{}{}, nil
}))
}
err = conc.AwaitAll(futures...)
if err != nil {
resp.Status = merr.Status(merr.WrapErrImportFailed(fmt.Sprintf("list binlogs failed, err=%s", err)))
resp.Status = merr.Status(err)
return resp, nil
}
files = lo.Filter(files, func(file *internalpb.ImportFile, _ int) bool {
return len(file.GetPaths()) > 0
})
if len(files) == 0 {
resp.Status = merr.Status(merr.WrapErrImportFailed(fmt.Sprintf("no binlog to import, input=%s", in.GetFiles())))
return resp, nil
}
if len(files) > paramtable.Get().DataCoordCfg.MaxFilesPerImportReq.GetAsInt() {
resp.Status = merr.Status(merr.WrapErrImportFailed(fmt.Sprintf("The max number of import files should not exceed %d, but got %d",
paramtable.Get().DataCoordCfg.MaxFilesPerImportReq.GetAsInt(), len(files))))
return resp, nil
}
log.Info("list binlogs prefixes for import", zap.Int("num", len(files)), zap.Any("binlog_prefixes", files))
}
// The import task does not need to be controlled for the time being, and additional development is required later.
// Here is a comment, because the current importv2 communicates through messages and needs to ensure idempotence.
// Adding this part of the logic will cause importv2 to retry infinitely until the previous import task is completed.
// Check if the number of jobs exceeds the limit.
// maxNum := paramtable.Get().DataCoordCfg.MaxImportJobNum.GetAsInt()
// executingNum := s.importMeta.CountJobBy(ctx, WithoutJobStates(internalpb.ImportJobState_Completed, internalpb.ImportJobState_Failed))
// if executingNum >= maxNum {
// resp.Status = merr.Status(merr.WrapErrImportFailed(
// fmt.Sprintf("The number of jobs has reached the limit, please try again later. " +
// "If your request is set to only import a single file, " +
// "please consider importing multiple files in one request for better efficiency.")))
// return resp, nil
// }
// Allocate file ids.
idStart, _, err := s.allocator.AllocN(int64(len(files)) + 1)

View File

@ -1404,13 +1404,6 @@ func TestImportV2(t *testing.T) {
assert.Equal(t, int32(0), resp.GetStatus().GetCode())
jobs = s.importMeta.GetJobBy(context.TODO())
assert.Equal(t, 1, len(jobs))
// number of jobs reached the limit
// Params.Save(paramtable.Get().DataCoordCfg.MaxImportJobNum.Key, "1")
// resp, err = s.ImportV2(ctx, &internalpb.ImportRequestInternal{})
// assert.NoError(t, err)
// assert.True(t, errors.Is(merr.Error(resp.GetStatus()), merr.ErrImportFailed))
// Params.Reset(paramtable.Get().DataCoordCfg.MaxImportJobNum.Key)
})
t.Run("GetImportProgress", func(t *testing.T) {

View File

@ -21,21 +21,8 @@ package msghandlerimpl
import (
"context"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/flushcommon/broker"
"github.com/milvus-io/milvus/internal/flushcommon/util"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/retry"
)
type msgHandlerImpl struct {
@ -54,41 +41,6 @@ func (m *msgHandlerImpl) HandleManualFlush(flushMsg message.ImmutableManualFlush
panic("unreachable code")
}
func (m *msgHandlerImpl) HandleImport(ctx context.Context, vchannel string, importMsg *msgpb.ImportMsg) error {
return retry.Do(ctx, func() (err error) {
defer func() {
if err == nil {
err = streaming.WAL().Broadcast().Ack(ctx, types.BroadcastAckRequest{
BroadcastID: uint64(importMsg.GetJobID()),
VChannel: vchannel,
})
}
}()
importResp, err := m.broker.ImportV2(ctx, &internalpb.ImportRequestInternal{
CollectionID: importMsg.GetCollectionID(),
CollectionName: importMsg.GetCollectionName(),
PartitionIDs: importMsg.GetPartitionIDs(),
ChannelNames: []string{vchannel},
Schema: importMsg.GetSchema(),
Files: lo.Map(importMsg.GetFiles(), util.ConvertInternalImportFile),
Options: funcutil.Map2KeyValuePair(importMsg.GetOptions()),
DataTimestamp: importMsg.GetBase().GetTimestamp(),
JobID: importMsg.GetJobID(),
})
err = merr.CheckRPCCall(importResp, err)
if errors.Is(err, merr.ErrCollectionNotFound) {
log.Ctx(ctx).Warn("import message failed because of collection not found, skip it", zap.String("job_id", importResp.GetJobID()), zap.Error(err))
return nil
}
if err != nil {
log.Ctx(ctx).Warn("import message failed", zap.String("job_id", importResp.GetJobID()), zap.Error(err))
return err
}
log.Ctx(ctx).Info("import message handled", zap.String("job_id", importResp.GetJobID()))
return nil
}, retry.AttemptAlways())
}
func (impl *msgHandlerImpl) HandleSchemaChange(ctx context.Context, msg message.ImmutableSchemaChangeMessageV2) error {
panic("unreachable code")
}

View File

@ -19,21 +19,16 @@
package msghandlerimpl
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/milvus-io/milvus/internal/distributed/streaming"
"github.com/milvus-io/milvus/internal/flushcommon/broker"
"github.com/milvus-io/milvus/internal/mocks/distributed/mock_streaming"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
func TestMsgHandlerImpl(t *testing.T) {
paramtable.Init()
ctx := context.Background()
b := broker.NewMockBroker(t)
m := NewMsgHandlerImpl(b)
assert.Panics(t, func() {
@ -45,17 +40,4 @@ func TestMsgHandlerImpl(t *testing.T) {
assert.Panics(t, func() {
m.HandleManualFlush(nil)
})
t.Run("HandleImport success", func(t *testing.T) {
wal := mock_streaming.NewMockWALAccesser(t)
bo := mock_streaming.NewMockBroadcast(t)
wal.EXPECT().Broadcast().Return(bo)
bo.EXPECT().Ack(mock.Anything, mock.Anything).Return(nil)
streaming.SetWALForTest(wal)
defer streaming.RecoverWALForTest()
b.EXPECT().ImportV2(mock.Anything, mock.Anything).Return(nil, assert.AnError).Once()
b.EXPECT().ImportV2(mock.Anything, mock.Anything).Return(nil, nil).Once()
err := m.HandleImport(ctx, "", nil)
assert.NoError(t, err)
})
}

View File

@ -282,21 +282,6 @@ func (ddn *ddNode) Operate(in []Msg) []Msg {
} else {
logger.Info("handle manual flush message success")
}
case commonpb.MsgType_Import:
importMsg := msg.(*msgstream.ImportMsg)
if importMsg.GetCollectionID() != ddn.collectionID {
continue
}
logger := log.With(
zap.String("vchannel", ddn.Name()),
zap.Int32("msgType", int32(msg.Type())),
)
logger.Info("receive import message")
if err := ddn.msgHandler.HandleImport(context.Background(), ddn.vChannelName, importMsg.ImportMsg); err != nil {
logger.Warn("handle import message failed", zap.Error(err))
} else {
logger.Info("handle import message success")
}
case commonpb.MsgType_AddCollectionField:
schemaMsg := msg.(*adaptor.SchemaChangeMessageBody)
header := schemaMsg.SchemaChangeMessage.Header()

View File

@ -31,8 +31,6 @@ type MsgHandler interface {
HandleManualFlush(flushMsg message.ImmutableManualFlushMessageV2) error
HandleImport(ctx context.Context, vchannel string, importMsg *msgpb.ImportMsg) error
HandleSchemaChange(ctx context.Context, schemaChangeMsg message.ImmutableSchemaChangeMessageV2) error
}

View File

@ -7,8 +7,6 @@ import (
message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
mock "github.com/stretchr/testify/mock"
msgpb "github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
)
// MockMsgHandler is an autogenerated mock type for the MsgHandler type
@ -117,54 +115,6 @@ func (_c *MockMsgHandler_HandleFlush_Call) RunAndReturn(run func(message.Immutab
return _c
}
// HandleImport provides a mock function with given fields: ctx, vchannel, importMsg
func (_m *MockMsgHandler) HandleImport(ctx context.Context, vchannel string, importMsg *msgpb.ImportMsg) error {
ret := _m.Called(ctx, vchannel, importMsg)
if len(ret) == 0 {
panic("no return value specified for HandleImport")
}
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, string, *msgpb.ImportMsg) error); ok {
r0 = rf(ctx, vchannel, importMsg)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockMsgHandler_HandleImport_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'HandleImport'
type MockMsgHandler_HandleImport_Call struct {
*mock.Call
}
// HandleImport is a helper method to define mock.On call
// - ctx context.Context
// - vchannel string
// - importMsg *msgpb.ImportMsg
func (_e *MockMsgHandler_Expecter) HandleImport(ctx interface{}, vchannel interface{}, importMsg interface{}) *MockMsgHandler_HandleImport_Call {
return &MockMsgHandler_HandleImport_Call{Call: _e.mock.On("HandleImport", ctx, vchannel, importMsg)}
}
func (_c *MockMsgHandler_HandleImport_Call) Run(run func(ctx context.Context, vchannel string, importMsg *msgpb.ImportMsg)) *MockMsgHandler_HandleImport_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(context.Context), args[1].(string), args[2].(*msgpb.ImportMsg))
})
return _c
}
func (_c *MockMsgHandler_HandleImport_Call) Return(_a0 error) *MockMsgHandler_HandleImport_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockMsgHandler_HandleImport_Call) RunAndReturn(run func(context.Context, string, *msgpb.ImportMsg) error) *MockMsgHandler_HandleImport_Call {
_c.Call.Return(run)
return _c
}
// HandleManualFlush provides a mock function with given fields: flushMsg
func (_m *MockMsgHandler) HandleManualFlush(flushMsg message.ImmutableManualFlushMessageV2) error {
ret := _m.Called(flushMsg)

View File

@ -78,21 +78,6 @@ func (bm *broadcastTaskManager) AddTask(ctx context.Context, msg message.Broadca
// assignID assigns the broadcast id to the message.
func (bm *broadcastTaskManager) assignID(ctx context.Context, msg message.BroadcastMutableMessage) (message.BroadcastMutableMessage, error) {
// TODO: current implementation the header cannot be seen at flusher itself.
// only import message use it, so temporarily set the broadcast id here.
// need to refactor the message to make the broadcast header visible to flusher.
if msg.MessageType() == message.MessageTypeImport {
importMsg, err := message.AsMutableImportMessageV1(msg)
if err != nil {
return nil, err
}
body, err := importMsg.Body()
if err != nil {
return nil, err
}
return msg.WithBroadcastID(uint64(body.JobID)), nil
}
id, err := resource.Resource().IDAllocator().Allocate(ctx)
if err != nil {
return nil, errors.Wrapf(err, "allocate new id failed")

View File

@ -7,6 +7,7 @@ import (
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
"github.com/milvus-io/milvus/internal/util/streamingutil/status"
"github.com/milvus-io/milvus/pkg/v2/log"
@ -69,6 +70,14 @@ func (b *broadcasterImpl) Broadcast(ctx context.Context, msg message.BroadcastMu
}
}()
// We need to check if the message is valid before adding it to the broadcaster.
// TODO: add resource key lock here to avoid state race condition.
// TODO: add all ddl to check operation here after ddl framework is ready.
if err := registry.CallMessageCheckCallback(ctx, msg); err != nil {
b.Logger().Warn("check message ack callback failed", zap.Error(err))
return nil, err
}
t, err := b.manager.AddTask(ctx, msg)
if err != nil {
return nil, err

View File

@ -14,6 +14,7 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/mocks/mock_metastore"
"github.com/milvus-io/milvus/internal/mocks/streamingcoord/server/mock_broadcaster"
"github.com/milvus-io/milvus/internal/streamingcoord/server/broadcaster/registry"
"github.com/milvus-io/milvus/internal/streamingcoord/server/resource"
internaltypes "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/idalloc"
@ -28,6 +29,7 @@ import (
)
func TestBroadcaster(t *testing.T) {
registry.ResetRegistration()
paramtable.Init()
meta := mock_metastore.NewMockStreamingCoordCataLog(t)

View File

@ -13,23 +13,25 @@ import (
// init the message ack callbacks
func init() {
resetMessageAckCallbacks()
resetMessageCheckCallbacks()
}
// resetMessageAckCallbacks resets the message ack callbacks.
func resetMessageAckCallbacks() {
messageAckCallbacks = map[message.MessageType]*syncutil.Future[MessageCallback]{
message.MessageTypeDropPartition: syncutil.NewFuture[MessageCallback](),
messageAckCallbacks = map[message.MessageType]*syncutil.Future[MessageAckCallback]{
message.MessageTypeDropPartition: syncutil.NewFuture[MessageAckCallback](),
message.MessageTypeImport: syncutil.NewFuture[MessageAckCallback](),
}
}
// MessageCallback is the callback function for the message type.
type MessageCallback = func(ctx context.Context, msg message.MutableMessage) error
// MessageAckCallback is the callback function for the message type.
type MessageAckCallback = func(ctx context.Context, msg message.MutableMessage) error
// messageAckCallbacks is the map of message type to the callback function.
var messageAckCallbacks map[message.MessageType]*syncutil.Future[MessageCallback]
var messageAckCallbacks map[message.MessageType]*syncutil.Future[MessageAckCallback]
// RegisterMessageAckCallback registers the callback function for the message type.
func RegisterMessageAckCallback(typ message.MessageType, callback MessageCallback) {
func RegisterMessageAckCallback(typ message.MessageType, callback MessageAckCallback) {
future, ok := messageAckCallbacks[typ]
if !ok {
panic(fmt.Sprintf("the future of message callback for type %s is not registered", typ))

View File

@ -0,0 +1,51 @@
package registry
import (
"context"
"fmt"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/syncutil"
)
// MessageCheckCallback is the callback function for the message type.
type MessageCheckCallback = func(ctx context.Context, msg message.BroadcastMutableMessage) error
// resetMessageCheckCallbacks resets the message check callbacks.
func resetMessageCheckCallbacks() {
messageCheckCallbacks = map[message.MessageType]*syncutil.Future[MessageCheckCallback]{
message.MessageTypeImport: syncutil.NewFuture[MessageCheckCallback](),
}
}
// messageCheckCallbacks is the map of message type to the callback function.
var messageCheckCallbacks map[message.MessageType]*syncutil.Future[MessageCheckCallback]
// RegisterMessageCheckCallback registers the callback function for the message type.
func RegisterMessageCheckCallback(typ message.MessageType, callback MessageCheckCallback) {
future, ok := messageCheckCallbacks[typ]
if !ok {
panic(fmt.Sprintf("the future of check message callback for type %s is not registered", typ))
}
if future.Ready() {
// only for test, the register callback should be called once and only once
return
}
future.Set(callback)
}
// CallMessageCheckCallback calls the callback function for the message type.
func CallMessageCheckCallback(ctx context.Context, msg message.BroadcastMutableMessage) error {
callbackFuture, ok := messageCheckCallbacks[msg.MessageType()]
if !ok {
// No callback need tobe called, return nil
return nil
}
callback, err := callbackFuture.GetWithContext(ctx)
if err != nil {
return errors.Wrap(err, "when waiting callback registered")
}
return callback(ctx, msg)
}

View File

@ -0,0 +1,50 @@
package registry
import (
"context"
"testing"
"time"
"github.com/cockroachdb/errors"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/pkg/v2/mocks/streaming/util/mock_message"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
)
func TestCheckMessageCallbackRegistration(t *testing.T) {
// Reset callbacks before test
resetMessageCheckCallbacks()
// Test registering a callback
called := false
callback := func(ctx context.Context, msg message.BroadcastMutableMessage) error {
called = true
return nil
}
// Register callback for DropPartition message type
RegisterMessageCheckCallback(message.MessageTypeImport, callback)
// Verify callback was registered
callbackFuture, ok := messageCheckCallbacks[message.MessageTypeImport]
assert.True(t, ok)
assert.NotNil(t, callbackFuture)
// Create a mock message
msg := mock_message.NewMockBroadcastMutableMessage(t)
msg.EXPECT().MessageType().Return(message.MessageTypeImport)
// Call the callback
err := CallMessageCheckCallback(context.Background(), msg)
assert.NoError(t, err)
assert.True(t, called)
resetMessageCheckCallbacks()
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
err = CallMessageCheckCallback(ctx, msg)
assert.Error(t, err)
assert.True(t, errors.Is(err, context.DeadlineExceeded))
}

View File

@ -10,4 +10,5 @@ func ResetRegistration() {
localRegistry[AppendOperatorTypeMsgstream] = syncutil.NewFuture[AppendOperator]()
localRegistry[AppendOperatorTypeStreaming] = syncutil.NewFuture[AppendOperator]()
resetMessageAckCallbacks()
resetMessageCheckCallbacks()
}

View File

@ -20,18 +20,13 @@ import (
"context"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/msgpb"
"github.com/milvus-io/milvus/internal/flushcommon/util"
"github.com/milvus-io/milvus/internal/flushcommon/writebuffer"
"github.com/milvus-io/milvus/internal/streamingnode/server/resource"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/proto/datapb"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/retry"
)
@ -109,34 +104,3 @@ func (impl *msgHandlerImpl) HandleManualFlush(flushMsg message.ImmutableManualFl
func (impl *msgHandlerImpl) HandleSchemaChange(ctx context.Context, msg message.ImmutableSchemaChangeMessageV2) error {
return impl.wbMgr.SealSegments(context.Background(), msg.VChannel(), msg.Header().FlushedSegmentIds)
}
func (impl *msgHandlerImpl) HandleImport(ctx context.Context, vchannel string, importMsg *msgpb.ImportMsg) error {
return retry.Do(ctx, func() (err error) {
client, err := resource.Resource().MixCoordClient().GetWithContext(ctx)
if err != nil {
return err
}
importResp, err := client.ImportV2(ctx, &internalpb.ImportRequestInternal{
CollectionID: importMsg.GetCollectionID(),
CollectionName: importMsg.GetCollectionName(),
PartitionIDs: importMsg.GetPartitionIDs(),
ChannelNames: []string{vchannel},
Schema: importMsg.GetSchema(),
Files: lo.Map(importMsg.GetFiles(), util.ConvertInternalImportFile),
Options: funcutil.Map2KeyValuePair(importMsg.GetOptions()),
DataTimestamp: importMsg.GetBase().GetTimestamp(),
JobID: importMsg.GetJobID(),
})
err = merr.CheckRPCCall(importResp, err)
if errors.Is(err, merr.ErrCollectionNotFound) {
log.Ctx(ctx).Warn("import message failed because of collection not found, skip it", zap.String("job_id", importResp.GetJobID()), zap.Error(err))
return nil
}
if err != nil {
log.Ctx(ctx).Warn("import message failed", zap.String("job_id", importResp.GetJobID()), zap.Error(err))
return err
}
log.Ctx(ctx).Info("import message handled", zap.String("job_id", importResp.GetJobID()))
return nil
}, retry.AttemptAlways())
}

View File

@ -14,6 +14,7 @@ packages:
ImmutableMessage:
ImmutableTxnMessage:
MutableMessage:
BroadcastMutableMessage:
RProperties:
github.com/milvus-io/milvus/pkg/v2/streaming/walimpls:
interfaces:

View File

@ -0,0 +1,636 @@
// Code generated by mockery v2.53.3. DO NOT EDIT.
package mock_message
import (
message "github.com/milvus-io/milvus/pkg/v2/streaming/util/message"
mock "github.com/stretchr/testify/mock"
zapcore "go.uber.org/zap/zapcore"
)
// MockBroadcastMutableMessage is an autogenerated mock type for the BroadcastMutableMessage type
type MockBroadcastMutableMessage struct {
mock.Mock
}
type MockBroadcastMutableMessage_Expecter struct {
mock *mock.Mock
}
func (_m *MockBroadcastMutableMessage) EXPECT() *MockBroadcastMutableMessage_Expecter {
return &MockBroadcastMutableMessage_Expecter{mock: &_m.Mock}
}
// BarrierTimeTick provides a mock function with no fields
func (_m *MockBroadcastMutableMessage) BarrierTimeTick() uint64 {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for BarrierTimeTick")
}
var r0 uint64
if rf, ok := ret.Get(0).(func() uint64); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(uint64)
}
return r0
}
// MockBroadcastMutableMessage_BarrierTimeTick_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BarrierTimeTick'
type MockBroadcastMutableMessage_BarrierTimeTick_Call struct {
*mock.Call
}
// BarrierTimeTick is a helper method to define mock.On call
func (_e *MockBroadcastMutableMessage_Expecter) BarrierTimeTick() *MockBroadcastMutableMessage_BarrierTimeTick_Call {
return &MockBroadcastMutableMessage_BarrierTimeTick_Call{Call: _e.mock.On("BarrierTimeTick")}
}
func (_c *MockBroadcastMutableMessage_BarrierTimeTick_Call) Run(run func()) *MockBroadcastMutableMessage_BarrierTimeTick_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockBroadcastMutableMessage_BarrierTimeTick_Call) Return(_a0 uint64) *MockBroadcastMutableMessage_BarrierTimeTick_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBroadcastMutableMessage_BarrierTimeTick_Call) RunAndReturn(run func() uint64) *MockBroadcastMutableMessage_BarrierTimeTick_Call {
_c.Call.Return(run)
return _c
}
// BroadcastHeader provides a mock function with no fields
func (_m *MockBroadcastMutableMessage) BroadcastHeader() *message.BroadcastHeader {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for BroadcastHeader")
}
var r0 *message.BroadcastHeader
if rf, ok := ret.Get(0).(func() *message.BroadcastHeader); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*message.BroadcastHeader)
}
}
return r0
}
// MockBroadcastMutableMessage_BroadcastHeader_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'BroadcastHeader'
type MockBroadcastMutableMessage_BroadcastHeader_Call struct {
*mock.Call
}
// BroadcastHeader is a helper method to define mock.On call
func (_e *MockBroadcastMutableMessage_Expecter) BroadcastHeader() *MockBroadcastMutableMessage_BroadcastHeader_Call {
return &MockBroadcastMutableMessage_BroadcastHeader_Call{Call: _e.mock.On("BroadcastHeader")}
}
func (_c *MockBroadcastMutableMessage_BroadcastHeader_Call) Run(run func()) *MockBroadcastMutableMessage_BroadcastHeader_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockBroadcastMutableMessage_BroadcastHeader_Call) Return(_a0 *message.BroadcastHeader) *MockBroadcastMutableMessage_BroadcastHeader_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBroadcastMutableMessage_BroadcastHeader_Call) RunAndReturn(run func() *message.BroadcastHeader) *MockBroadcastMutableMessage_BroadcastHeader_Call {
_c.Call.Return(run)
return _c
}
// EstimateSize provides a mock function with no fields
func (_m *MockBroadcastMutableMessage) EstimateSize() int {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for EstimateSize")
}
var r0 int
if rf, ok := ret.Get(0).(func() int); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(int)
}
return r0
}
// MockBroadcastMutableMessage_EstimateSize_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'EstimateSize'
type MockBroadcastMutableMessage_EstimateSize_Call struct {
*mock.Call
}
// EstimateSize is a helper method to define mock.On call
func (_e *MockBroadcastMutableMessage_Expecter) EstimateSize() *MockBroadcastMutableMessage_EstimateSize_Call {
return &MockBroadcastMutableMessage_EstimateSize_Call{Call: _e.mock.On("EstimateSize")}
}
func (_c *MockBroadcastMutableMessage_EstimateSize_Call) Run(run func()) *MockBroadcastMutableMessage_EstimateSize_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockBroadcastMutableMessage_EstimateSize_Call) Return(_a0 int) *MockBroadcastMutableMessage_EstimateSize_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBroadcastMutableMessage_EstimateSize_Call) RunAndReturn(run func() int) *MockBroadcastMutableMessage_EstimateSize_Call {
_c.Call.Return(run)
return _c
}
// IsPersisted provides a mock function with no fields
func (_m *MockBroadcastMutableMessage) IsPersisted() bool {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for IsPersisted")
}
var r0 bool
if rf, ok := ret.Get(0).(func() bool); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(bool)
}
return r0
}
// MockBroadcastMutableMessage_IsPersisted_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IsPersisted'
type MockBroadcastMutableMessage_IsPersisted_Call struct {
*mock.Call
}
// IsPersisted is a helper method to define mock.On call
func (_e *MockBroadcastMutableMessage_Expecter) IsPersisted() *MockBroadcastMutableMessage_IsPersisted_Call {
return &MockBroadcastMutableMessage_IsPersisted_Call{Call: _e.mock.On("IsPersisted")}
}
func (_c *MockBroadcastMutableMessage_IsPersisted_Call) Run(run func()) *MockBroadcastMutableMessage_IsPersisted_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockBroadcastMutableMessage_IsPersisted_Call) Return(_a0 bool) *MockBroadcastMutableMessage_IsPersisted_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBroadcastMutableMessage_IsPersisted_Call) RunAndReturn(run func() bool) *MockBroadcastMutableMessage_IsPersisted_Call {
_c.Call.Return(run)
return _c
}
// MarshalLogObject provides a mock function with given fields: _a0
func (_m *MockBroadcastMutableMessage) MarshalLogObject(_a0 zapcore.ObjectEncoder) error {
ret := _m.Called(_a0)
if len(ret) == 0 {
panic("no return value specified for MarshalLogObject")
}
var r0 error
if rf, ok := ret.Get(0).(func(zapcore.ObjectEncoder) error); ok {
r0 = rf(_a0)
} else {
r0 = ret.Error(0)
}
return r0
}
// MockBroadcastMutableMessage_MarshalLogObject_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MarshalLogObject'
type MockBroadcastMutableMessage_MarshalLogObject_Call struct {
*mock.Call
}
// MarshalLogObject is a helper method to define mock.On call
// - _a0 zapcore.ObjectEncoder
func (_e *MockBroadcastMutableMessage_Expecter) MarshalLogObject(_a0 interface{}) *MockBroadcastMutableMessage_MarshalLogObject_Call {
return &MockBroadcastMutableMessage_MarshalLogObject_Call{Call: _e.mock.On("MarshalLogObject", _a0)}
}
func (_c *MockBroadcastMutableMessage_MarshalLogObject_Call) Run(run func(_a0 zapcore.ObjectEncoder)) *MockBroadcastMutableMessage_MarshalLogObject_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(zapcore.ObjectEncoder))
})
return _c
}
func (_c *MockBroadcastMutableMessage_MarshalLogObject_Call) Return(_a0 error) *MockBroadcastMutableMessage_MarshalLogObject_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBroadcastMutableMessage_MarshalLogObject_Call) RunAndReturn(run func(zapcore.ObjectEncoder) error) *MockBroadcastMutableMessage_MarshalLogObject_Call {
_c.Call.Return(run)
return _c
}
// MessageType provides a mock function with no fields
func (_m *MockBroadcastMutableMessage) MessageType() message.MessageType {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for MessageType")
}
var r0 message.MessageType
if rf, ok := ret.Get(0).(func() message.MessageType); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(message.MessageType)
}
return r0
}
// MockBroadcastMutableMessage_MessageType_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'MessageType'
type MockBroadcastMutableMessage_MessageType_Call struct {
*mock.Call
}
// MessageType is a helper method to define mock.On call
func (_e *MockBroadcastMutableMessage_Expecter) MessageType() *MockBroadcastMutableMessage_MessageType_Call {
return &MockBroadcastMutableMessage_MessageType_Call{Call: _e.mock.On("MessageType")}
}
func (_c *MockBroadcastMutableMessage_MessageType_Call) Run(run func()) *MockBroadcastMutableMessage_MessageType_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockBroadcastMutableMessage_MessageType_Call) Return(_a0 message.MessageType) *MockBroadcastMutableMessage_MessageType_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBroadcastMutableMessage_MessageType_Call) RunAndReturn(run func() message.MessageType) *MockBroadcastMutableMessage_MessageType_Call {
_c.Call.Return(run)
return _c
}
// Payload provides a mock function with no fields
func (_m *MockBroadcastMutableMessage) Payload() []byte {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Payload")
}
var r0 []byte
if rf, ok := ret.Get(0).(func() []byte); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]byte)
}
}
return r0
}
// MockBroadcastMutableMessage_Payload_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Payload'
type MockBroadcastMutableMessage_Payload_Call struct {
*mock.Call
}
// Payload is a helper method to define mock.On call
func (_e *MockBroadcastMutableMessage_Expecter) Payload() *MockBroadcastMutableMessage_Payload_Call {
return &MockBroadcastMutableMessage_Payload_Call{Call: _e.mock.On("Payload")}
}
func (_c *MockBroadcastMutableMessage_Payload_Call) Run(run func()) *MockBroadcastMutableMessage_Payload_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockBroadcastMutableMessage_Payload_Call) Return(_a0 []byte) *MockBroadcastMutableMessage_Payload_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBroadcastMutableMessage_Payload_Call) RunAndReturn(run func() []byte) *MockBroadcastMutableMessage_Payload_Call {
_c.Call.Return(run)
return _c
}
// Properties provides a mock function with no fields
func (_m *MockBroadcastMutableMessage) Properties() message.RProperties {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Properties")
}
var r0 message.RProperties
if rf, ok := ret.Get(0).(func() message.RProperties); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(message.RProperties)
}
}
return r0
}
// MockBroadcastMutableMessage_Properties_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Properties'
type MockBroadcastMutableMessage_Properties_Call struct {
*mock.Call
}
// Properties is a helper method to define mock.On call
func (_e *MockBroadcastMutableMessage_Expecter) Properties() *MockBroadcastMutableMessage_Properties_Call {
return &MockBroadcastMutableMessage_Properties_Call{Call: _e.mock.On("Properties")}
}
func (_c *MockBroadcastMutableMessage_Properties_Call) Run(run func()) *MockBroadcastMutableMessage_Properties_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockBroadcastMutableMessage_Properties_Call) Return(_a0 message.RProperties) *MockBroadcastMutableMessage_Properties_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBroadcastMutableMessage_Properties_Call) RunAndReturn(run func() message.RProperties) *MockBroadcastMutableMessage_Properties_Call {
_c.Call.Return(run)
return _c
}
// SplitIntoMutableMessage provides a mock function with no fields
func (_m *MockBroadcastMutableMessage) SplitIntoMutableMessage() []message.MutableMessage {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for SplitIntoMutableMessage")
}
var r0 []message.MutableMessage
if rf, ok := ret.Get(0).(func() []message.MutableMessage); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]message.MutableMessage)
}
}
return r0
}
// MockBroadcastMutableMessage_SplitIntoMutableMessage_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SplitIntoMutableMessage'
type MockBroadcastMutableMessage_SplitIntoMutableMessage_Call struct {
*mock.Call
}
// SplitIntoMutableMessage is a helper method to define mock.On call
func (_e *MockBroadcastMutableMessage_Expecter) SplitIntoMutableMessage() *MockBroadcastMutableMessage_SplitIntoMutableMessage_Call {
return &MockBroadcastMutableMessage_SplitIntoMutableMessage_Call{Call: _e.mock.On("SplitIntoMutableMessage")}
}
func (_c *MockBroadcastMutableMessage_SplitIntoMutableMessage_Call) Run(run func()) *MockBroadcastMutableMessage_SplitIntoMutableMessage_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockBroadcastMutableMessage_SplitIntoMutableMessage_Call) Return(_a0 []message.MutableMessage) *MockBroadcastMutableMessage_SplitIntoMutableMessage_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBroadcastMutableMessage_SplitIntoMutableMessage_Call) RunAndReturn(run func() []message.MutableMessage) *MockBroadcastMutableMessage_SplitIntoMutableMessage_Call {
_c.Call.Return(run)
return _c
}
// TimeTick provides a mock function with no fields
func (_m *MockBroadcastMutableMessage) TimeTick() uint64 {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for TimeTick")
}
var r0 uint64
if rf, ok := ret.Get(0).(func() uint64); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(uint64)
}
return r0
}
// MockBroadcastMutableMessage_TimeTick_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TimeTick'
type MockBroadcastMutableMessage_TimeTick_Call struct {
*mock.Call
}
// TimeTick is a helper method to define mock.On call
func (_e *MockBroadcastMutableMessage_Expecter) TimeTick() *MockBroadcastMutableMessage_TimeTick_Call {
return &MockBroadcastMutableMessage_TimeTick_Call{Call: _e.mock.On("TimeTick")}
}
func (_c *MockBroadcastMutableMessage_TimeTick_Call) Run(run func()) *MockBroadcastMutableMessage_TimeTick_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockBroadcastMutableMessage_TimeTick_Call) Return(_a0 uint64) *MockBroadcastMutableMessage_TimeTick_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBroadcastMutableMessage_TimeTick_Call) RunAndReturn(run func() uint64) *MockBroadcastMutableMessage_TimeTick_Call {
_c.Call.Return(run)
return _c
}
// TxnContext provides a mock function with no fields
func (_m *MockBroadcastMutableMessage) TxnContext() *message.TxnContext {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for TxnContext")
}
var r0 *message.TxnContext
if rf, ok := ret.Get(0).(func() *message.TxnContext); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*message.TxnContext)
}
}
return r0
}
// MockBroadcastMutableMessage_TxnContext_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'TxnContext'
type MockBroadcastMutableMessage_TxnContext_Call struct {
*mock.Call
}
// TxnContext is a helper method to define mock.On call
func (_e *MockBroadcastMutableMessage_Expecter) TxnContext() *MockBroadcastMutableMessage_TxnContext_Call {
return &MockBroadcastMutableMessage_TxnContext_Call{Call: _e.mock.On("TxnContext")}
}
func (_c *MockBroadcastMutableMessage_TxnContext_Call) Run(run func()) *MockBroadcastMutableMessage_TxnContext_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockBroadcastMutableMessage_TxnContext_Call) Return(_a0 *message.TxnContext) *MockBroadcastMutableMessage_TxnContext_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBroadcastMutableMessage_TxnContext_Call) RunAndReturn(run func() *message.TxnContext) *MockBroadcastMutableMessage_TxnContext_Call {
_c.Call.Return(run)
return _c
}
// Version provides a mock function with no fields
func (_m *MockBroadcastMutableMessage) Version() message.Version {
ret := _m.Called()
if len(ret) == 0 {
panic("no return value specified for Version")
}
var r0 message.Version
if rf, ok := ret.Get(0).(func() message.Version); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(message.Version)
}
return r0
}
// MockBroadcastMutableMessage_Version_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Version'
type MockBroadcastMutableMessage_Version_Call struct {
*mock.Call
}
// Version is a helper method to define mock.On call
func (_e *MockBroadcastMutableMessage_Expecter) Version() *MockBroadcastMutableMessage_Version_Call {
return &MockBroadcastMutableMessage_Version_Call{Call: _e.mock.On("Version")}
}
func (_c *MockBroadcastMutableMessage_Version_Call) Run(run func()) *MockBroadcastMutableMessage_Version_Call {
_c.Call.Run(func(args mock.Arguments) {
run()
})
return _c
}
func (_c *MockBroadcastMutableMessage_Version_Call) Return(_a0 message.Version) *MockBroadcastMutableMessage_Version_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBroadcastMutableMessage_Version_Call) RunAndReturn(run func() message.Version) *MockBroadcastMutableMessage_Version_Call {
_c.Call.Return(run)
return _c
}
// WithBroadcastID provides a mock function with given fields: broadcastID
func (_m *MockBroadcastMutableMessage) WithBroadcastID(broadcastID uint64) message.BroadcastMutableMessage {
ret := _m.Called(broadcastID)
if len(ret) == 0 {
panic("no return value specified for WithBroadcastID")
}
var r0 message.BroadcastMutableMessage
if rf, ok := ret.Get(0).(func(uint64) message.BroadcastMutableMessage); ok {
r0 = rf(broadcastID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(message.BroadcastMutableMessage)
}
}
return r0
}
// MockBroadcastMutableMessage_WithBroadcastID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'WithBroadcastID'
type MockBroadcastMutableMessage_WithBroadcastID_Call struct {
*mock.Call
}
// WithBroadcastID is a helper method to define mock.On call
// - broadcastID uint64
func (_e *MockBroadcastMutableMessage_Expecter) WithBroadcastID(broadcastID interface{}) *MockBroadcastMutableMessage_WithBroadcastID_Call {
return &MockBroadcastMutableMessage_WithBroadcastID_Call{Call: _e.mock.On("WithBroadcastID", broadcastID)}
}
func (_c *MockBroadcastMutableMessage_WithBroadcastID_Call) Run(run func(broadcastID uint64)) *MockBroadcastMutableMessage_WithBroadcastID_Call {
_c.Call.Run(func(args mock.Arguments) {
run(args[0].(uint64))
})
return _c
}
func (_c *MockBroadcastMutableMessage_WithBroadcastID_Call) Return(_a0 message.BroadcastMutableMessage) *MockBroadcastMutableMessage_WithBroadcastID_Call {
_c.Call.Return(_a0)
return _c
}
func (_c *MockBroadcastMutableMessage_WithBroadcastID_Call) RunAndReturn(run func(uint64) message.BroadcastMutableMessage) *MockBroadcastMutableMessage_WithBroadcastID_Call {
_c.Call.Return(run)
return _c
}
// NewMockBroadcastMutableMessage creates a new instance of MockBroadcastMutableMessage. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
// The first argument is typically a *testing.T value.
func NewMockBroadcastMutableMessage(t interface {
mock.TestingT
Cleanup(func())
}) *MockBroadcastMutableMessage {
mock := &MockBroadcastMutableMessage{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}

View File

@ -167,6 +167,9 @@ type specializedMutableMessage[H proto.Message, B proto.Message] interface {
// !!! Do these will trigger a unmarshal operation, so it should be used with caution.
Body() (B, error)
// MustBody return the message body, panic if error occurs.
MustBody() B
// OverwriteHeader overwrites the message header.
OverwriteHeader(header H)
}

View File

@ -74,6 +74,8 @@ func TestBroadcast(t *testing.T) {
assert.Equal(t, uint64(1), msgs[1].BroadcastHeader().BroadcastID)
assert.Len(t, msgs[0].BroadcastHeader().ResourceKeys, 2)
assert.ElementsMatch(t, []string{"v1", "v2"}, []string{msgs[0].VChannel(), msgs[1].VChannel()})
MustAsMutableCreateCollectionMessageV1(msg)
}
func TestCiper(t *testing.T) {

View File

@ -349,6 +349,15 @@ func (m *specializedMutableMessageImpl[H, B]) Body() (B, error) {
return unmarshalProtoB[B](m.Payload())
}
// MustBody returns the message body.
func (m *specializedMutableMessageImpl[H, B]) MustBody() B {
b, err := m.Body()
if err != nil {
panic(fmt.Sprintf("failed to unmarshal specialized body,%s", err.Error()))
}
return b
}
// OverwriteMessageHeader overwrites the message header.
func (m *specializedMutableMessageImpl[H, B]) OverwriteHeader(header H) {
m.header = header