diff --git a/configs/milvus.yaml b/configs/milvus.yaml
index 8a6d9f4505..8b227c9f1c 100644
--- a/configs/milvus.yaml
+++ b/configs/milvus.yaml
@@ -1483,3 +1483,4 @@ function:
           url: # Your voyageai rerank url, Default is the official rerank url
   analyzer:
     local_resource_path: /var/lib/milvus/analyzer
+    concurrency_per_cpu_core: 8 # The concurrency per cpu core for analyzer, pipeline not included
diff --git a/internal/util/function/bm25_function.go b/internal/util/function/bm25_function.go
index 71c4440ce0..ffb9e7a726 100644
--- a/internal/util/function/bm25_function.go
+++ b/internal/util/function/bm25_function.go
@@ -24,16 +24,47 @@ import (

 	"github.com/cockroachdb/errors"
 	"github.com/samber/lo"
+	"go.uber.org/zap"

 	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/internal/util/analyzer"
+	"github.com/milvus-io/milvus/pkg/v2/log"
+	"github.com/milvus-io/milvus/pkg/v2/util/conc"
+	"github.com/milvus-io/milvus/pkg/v2/util/hardware"
 	"github.com/milvus-io/milvus/pkg/v2/util/merr"
+	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
 	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
 )

 const analyzerParams = "analyzer_params"

+// analyzerPool is the pool shared by the BM25 function runners for batch
+// analysis; analyzerPoolInitOnce guards its lazy, one-time construction.
+var (
+	analyzerPool         *conc.Pool[struct{}]
+	analyzerPoolInitOnce sync.Once
+)
+
+// initAnalyzerPool creates analyzerPool with size cpuNum * function.analyzer.concurrency_per_cpu_core,
+// falling back to cpuNum when that product is not positive.
+func initAnalyzerPool() {
+	cpuNum := hardware.GetCPUNum()
+	initPoolSize := int(float64(cpuNum) * paramtable.Get().FunctionCfg.AnalyzerConcurrencyPerCPUCore.GetAsFloat())
+	if initPoolSize <= 0 {
+		log.Warn("analyzer pool size is not positive, set to cpu num", zap.Int("cpuNum", cpuNum))
+		initPoolSize = cpuNum
+	}
+	analyzerPool = conc.NewPool[struct{}](initPoolSize)
+}
+
+// getOrCreateAnalyzerPool returns the shared analyzer pool, creating it on the first call.
+// The pool is sized only once, so later changes to the config value are not picked up.
+func getOrCreateAnalyzerPool() *conc.Pool[struct{}] {
+	analyzerPoolInitOnce.Do(initAnalyzerPool)
+	return analyzerPool
+}
+
 type Analyzer interface {
 	BatchAnalyze(withDetail bool, withHash bool, inputs ...any) ([][]*milvuspb.AnalyzerToken, error)
 	GetInputFields() []*schemapb.FieldSchema
@@ -247,33 +278,25 @@ func (v *BM25FunctionRunner) BatchAnalyze(withDetail bool, withHash bool, inputs

 	rowNum := len(text)
 	result := make([][]*milvuspb.AnalyzerToken, rowNum)
-	wg := sync.WaitGroup{}
+	pool := getOrCreateAnalyzerPool()
+	futures := make([]*conc.Future[struct{}], 0, v.concurrency)

-	errCh := make(chan error, v.concurrency)
 	for i, j := 0, 0; i < v.concurrency && j < rowNum; i++ {
 		start := j
 		end := start + rowNum/v.concurrency
 		if i < rowNum%v.concurrency {
 			end += 1
 		}
-		wg.Add(1)
-		go func() {
-			defer wg.Done()
-			err := v.analyze(text[start:end], result[start:end], withDetail, withHash)
-			if err != nil {
-				errCh <- err
-				return
-			}
-		}()
+		future := pool.Submit(func() (struct{}, error) {
+			return struct{}{}, v.analyze(text[start:end], result[start:end], withDetail, withHash)
+		})
+		futures = append(futures, future)
 		j = end
 	}

-	wg.Wait()
-	close(errCh)
-	for err := range errCh {
-		if err != nil {
-			return nil, err
-		}
+	err := conc.AwaitAll(futures...)
+	if err != nil {
+		return nil, err
 	}
 	return result, nil
 }
diff --git a/internal/util/function/bm25_function_test.go b/internal/util/function/bm25_function_test.go
index 60e54a2579..b77cb49391 100644
--- a/internal/util/function/bm25_function_test.go
+++ b/internal/util/function/bm25_function_test.go
@@ -46,7 +46,7 @@ func (s *BM25FunctionRunnerSuite) SetupTest() {
 	}
 }

-func (s *BM25FunctionRunnerSuite) TestBM25() {
+func (s *BM25FunctionRunnerSuite) TestBatchRun() {
 	_, err := NewFunctionRunner(s.schema, &schemapb.FunctionSchema{
 		Name:          "test",
 		Type:          schemapb.FunctionType_BM25,
@@ -86,3 +86,23 @@ func (s *BM25FunctionRunnerSuite) TestBM25() {
 	_, err = runner.BatchRun([]string{"test string", "test string 2"})
 	s.Error(err)
 }
+
+func (s *BM25FunctionRunnerSuite) TestBatchAnalyze() {
+	runner, err := NewFunctionRunner(s.schema, &schemapb.FunctionSchema{
+		Name:           "test",
+		Type:           schemapb.FunctionType_BM25,
+		InputFieldIds:  []int64{101},
+		OutputFieldIds: []int64{102},
+	})
+	s.NoError(err)
+
+	analyzer, ok := runner.(Analyzer)
+	s.True(ok)
+
+	result, err := analyzer.BatchAnalyze(true, false, []string{"test string", "test string 2"})
+	s.NoError(err)
+
+	s.Equal(2, len(result))
+	s.Equal(2, len(result[0]))
+	s.Equal(3, len(result[1]))
+}
diff --git a/internal/util/function/multi_analyzer_bm25_function.go b/internal/util/function/multi_analyzer_bm25_function.go
index 535f62c70a..5364132c73 100644
--- a/internal/util/function/multi_analyzer_bm25_function.go
+++ b/internal/util/function/multi_analyzer_bm25_function.go
@@ -26,6 +26,7 @@ import (
 	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
 	"github.com/milvus-io/milvus/internal/util/analyzer"
+	"github.com/milvus-io/milvus/pkg/v2/util/conc"
 	"github.com/milvus-io/milvus/pkg/v2/util/merr"
 	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
 )
@@ -308,33 +309,25 @@ func (v *MultiAnalyzerBM25FunctionRunner) BatchAnalyze(withDetail bool, withHash

 	rowNum := len(text)
 	result := make([][]*milvuspb.AnalyzerToken, rowNum)
-	wg := sync.WaitGroup{}
+	pool := getOrCreateAnalyzerPool()
+	futures := make([]*conc.Future[struct{}], 0, v.concurrency)

-	errCh := make(chan error, v.concurrency)
 	for i, j := 0, 0; i < v.concurrency && j < rowNum; i++ {
 		start := j
 		end := start + rowNum/v.concurrency
 		if i < rowNum%v.concurrency {
 			end += 1
 		}
-		wg.Add(1)
-		go func() {
-			defer wg.Done()
-			err := v.analyze(text[start:end], analyzer[start:end], result[start:end], withDetail, withHash)
-			if err != nil {
-				errCh <- err
-				return
-			}
-		}()
+		future := pool.Submit(func() (struct{}, error) {
+			return struct{}{}, v.analyze(text[start:end], analyzer[start:end], result[start:end], withDetail, withHash)
+		})
+		futures = append(futures, future)
 		j = end
 	}

-	wg.Wait()
-	close(errCh)
-	for err := range errCh {
-		if err != nil {
-			return nil, err
-		}
+	err := conc.AwaitAll(futures...)
+	if err != nil {
+		return nil, err
 	}
 	return result, nil
 }
diff --git a/internal/util/function/multi_analyzer_bm25_function_test.go b/internal/util/function/multi_analyzer_bm25_function_test.go
index 819637d61a..959541ea9b 100644
--- a/internal/util/function/multi_analyzer_bm25_function_test.go
+++ b/internal/util/function/multi_analyzer_bm25_function_test.go
@@ -164,6 +164,27 @@ func (s *MultiAnalyzerBM25FunctionSuite) TestBatchRun() {
 	})
 }

+func (s *MultiAnalyzerBM25FunctionSuite) TestBatchAnalyze() {
+	s.Run("normal", func() {
+		runner, err := NewBM25FunctionRunner(s.collection, s.function)
+		s.NoError(err)
+		s.NotNil(runner)
+
+		analyzer, ok := runner.(Analyzer)
+		s.True(ok)
+
+		text := []string{"test of analyzer", "test of analyzer"}
+		analyzerName := []string{"english", "default"}
+
+		result, err := analyzer.BatchAnalyze(true, false, text, analyzerName)
+		s.NoError(err)
+
+		s.Equal(2, len(result))
+		s.Equal(2, len(result[0]))
+		s.Equal(3, len(result[1]))
+	})
+}
+
 func TestMultiAnalyzerBm25Function(t *testing.T) {
 	suite.Run(t, new(MultiAnalyzerBM25FunctionSuite))
 }
diff --git a/pkg/util/paramtable/function_param.go b/pkg/util/paramtable/function_param.go
index 40fb24b71c..af54460143 100644
--- a/pkg/util/paramtable/function_param.go
+++ b/pkg/util/paramtable/function_param.go
@@ -21,12 +21,13 @@ import (
 )

 type functionConfig struct {
-	BatchFactor            ParamItem  `refreshable:"true"`
-	TextEmbeddingProviders ParamGroup `refreshable:"true"`
-	RerankModelProviders   ParamGroup `refreshable:"true"`
-	LocalResourcePath      ParamItem  `refreshable:"true"`
-	LinderaDownloadUrls    ParamGroup `refreshable:"true"`
-	ZillizProviders        ParamGroup `refreshable:"true"`
+	BatchFactor                   ParamItem  `refreshable:"true"`
+	TextEmbeddingProviders        ParamGroup `refreshable:"true"`
+	RerankModelProviders          ParamGroup `refreshable:"true"`
+	LocalResourcePath             ParamItem  `refreshable:"true"`
+	LinderaDownloadUrls           ParamGroup `refreshable:"true"`
+	ZillizProviders               ParamGroup `refreshable:"true"`
+	AnalyzerConcurrencyPerCPUCore ParamItem  `refreshable:"false"`
 }

 func (p *functionConfig) init(base *BaseTable) {
@@ -160,6 +161,15 @@ func (p *functionConfig) init(base *BaseTable) {
 		Version:   "2.6.5",
 	}
 	p.ZillizProviders.Init(base.mgr)
+
+	p.AnalyzerConcurrencyPerCPUCore = ParamItem{
+		Key:          "function.analyzer.concurrency_per_cpu_core",
+		Version:      "2.6.8",
+		Export:       true,
+		Doc:          "The concurrency per cpu core for analyzer, pipeline not included",
+		DefaultValue: "8",
+	}
+	p.AnalyzerConcurrencyPerCPUCore.Init(base.mgr)
 }

 func (p *functionConfig) GetTextEmbeddingProviderConfig(providerName string) map[string]string {