From 13330bd46624b8768681be0aa35dfd9f4bcd3ffa Mon Sep 17 00:00:00 2001 From: aoiasd <45024769+aoiasd@users.noreply.github.com> Date: Tue, 10 Jun 2025 11:36:34 +0800 Subject: [PATCH] fix: add concurrency and close protection for bm25 function (#42597) related: https://github.com/milvus-io/milvus/issues/42576 Signed-off-by: aoiasd --- internal/querynodev2/delegator/delegator.go | 6 ++++++ internal/util/function/bm25_function.go | 21 +++++++++++++++++++ ...function_test.go => bm25_function_test.go} | 16 +++++++++----- .../function/multi_analyzer_bm25_function.go | 21 +++++++++++++++++++ .../multi_analyzer_bm25_function_test.go | 6 ++++++ 5 files changed, 65 insertions(+), 5 deletions(-) rename internal/util/function/{function_test.go => bm25_function_test.go} (86%) diff --git a/internal/querynodev2/delegator/delegator.go b/internal/querynodev2/delegator/delegator.go index 10c4bd6744..ba4b2f94ad 100644 --- a/internal/querynodev2/delegator/delegator.go +++ b/internal/querynodev2/delegator/delegator.go @@ -1043,6 +1043,12 @@ func (sd *shardDelegator) Close() { sd.idfOracle.Close() } + if sd.functionRunners != nil { + for _, function := range sd.functionRunners { + function.Close() + } + } + // clean up l0 segment in delete buffer start := time.Now() sd.deleteBuffer.Clear() diff --git a/internal/util/function/bm25_function.go b/internal/util/function/bm25_function.go index f03fcb23a1..4a2dd2b923 100644 --- a/internal/util/function/bm25_function.go +++ b/internal/util/function/bm25_function.go @@ -43,6 +43,9 @@ type Analyzer interface { // Input: string // Output: map[uint32]float32 type BM25FunctionRunner struct { + mu sync.RWMutex + closed bool + tokenizer tokenizerapi.Tokenizer schema *schemapb.FunctionSchema outputField *schemapb.FieldSchema @@ -122,6 +125,13 @@ func (v *BM25FunctionRunner) run(data []string, dst []map[uint32]float32) error } func (v *BM25FunctionRunner) BatchRun(inputs ...any) ([]any, error) { + v.mu.RLock() + defer v.mu.RUnlock() + + if v.closed { +
return nil, errors.New("analyzer receview request after function closed") + } + if len(inputs) > 1 { return nil, errors.New("BM25 function received more than one input column") } @@ -197,6 +207,13 @@ func (v *BM25FunctionRunner) analyze(data []string, dst [][]*milvuspb.AnalyzerTo } func (v *BM25FunctionRunner) BatchAnalyze(withDetail bool, withHash bool, inputs ...any) ([][]*milvuspb.AnalyzerToken, error) { + v.mu.RLock() + defer v.mu.RUnlock() + + if v.closed { + return nil, errors.New("analyzer receview request after function closed") + } + if len(inputs) > 1 { return nil, errors.New("analyze received should only receive text input column(not set analyzer name)") } @@ -252,6 +269,10 @@ func (v *BM25FunctionRunner) GetInputFields() []*schemapb.FieldSchema { } func (v *BM25FunctionRunner) Close() { + v.mu.Lock() + defer v.mu.Unlock() + + v.closed = true v.tokenizer.Destroy() } diff --git a/internal/util/function/function_test.go b/internal/util/function/bm25_function_test.go similarity index 86% rename from internal/util/function/function_test.go rename to internal/util/function/bm25_function_test.go index 964ec3fc8e..60e54a2579 100644 --- a/internal/util/function/function_test.go +++ b/internal/util/function/bm25_function_test.go @@ -26,16 +26,16 @@ import ( "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" ) -func TestFunctionRunnerSuite(t *testing.T) { - suite.Run(t, new(FunctionRunnerSuite)) +func TestBM25FunctionRunnerSuite(t *testing.T) { + suite.Run(t, new(BM25FunctionRunnerSuite)) } -type FunctionRunnerSuite struct { +type BM25FunctionRunnerSuite struct { suite.Suite schema *schemapb.CollectionSchema } -func (s *FunctionRunnerSuite) SetupTest() { +func (s *BM25FunctionRunnerSuite) SetupTest() { s.schema = &schemapb.CollectionSchema{ Name: "test", Fields: []*schemapb.FieldSchema{ @@ -46,7 +46,7 @@ func (s *FunctionRunnerSuite) SetupTest() { } } -func (s *FunctionRunnerSuite) TestBM25() { +func (s *BM25FunctionRunnerSuite) TestBM25() { _, err := 
NewFunctionRunner(s.schema, &schemapb.FunctionSchema{ Name: "test", Type: schemapb.FunctionType_BM25, @@ -79,4 +79,10 @@ func (s *FunctionRunnerSuite) TestBM25() { // return error because field not string _, err = runner.BatchRun([]int64{}) s.Error(err) + + runner.Close() + + // run after close + _, err = runner.BatchRun([]string{"test string", "test string 2"}) + s.Error(err) } diff --git a/internal/util/function/multi_analyzer_bm25_function.go b/internal/util/function/multi_analyzer_bm25_function.go index 9e24a6385b..90065734e2 100644 --- a/internal/util/function/multi_analyzer_bm25_function.go +++ b/internal/util/function/multi_analyzer_bm25_function.go @@ -36,6 +36,9 @@ const multiAnalyzerParams = "multi_analyzer_params" // Input: string string // text, analyzer name // Output: map[uint32]float32 type MultiAnalyzerBM25FunctionRunner struct { + mu sync.RWMutex + closed bool + analyzers map[string]tokenizerapi.Tokenizer alias map[string]string // alias -> analyzer name schema *schemapb.FunctionSchema @@ -173,6 +176,13 @@ func (v *MultiAnalyzerBM25FunctionRunner) run(text []string, analyzerName []stri } func (v *MultiAnalyzerBM25FunctionRunner) BatchRun(inputs ...any) ([]any, error) { + v.mu.RLock() + defer v.mu.RUnlock() + + if v.closed { + return nil, fmt.Errorf("analyzer receview request after function closed") + } + if len(inputs) != 2 { return nil, fmt.Errorf("BM25 function with multi analyzer must received two input column") } @@ -263,6 +273,13 @@ func (v *MultiAnalyzerBM25FunctionRunner) analyze(data []string, analyzerName [] } func (v *MultiAnalyzerBM25FunctionRunner) BatchAnalyze(withDetail bool, withHash bool, inputs ...any) ([][]*milvuspb.AnalyzerToken, error) { + v.mu.RLock() + defer v.mu.RUnlock() + + if v.closed { + return nil, fmt.Errorf("analyzer receview request after function closed") + } + if len(inputs) != 2 { return nil, fmt.Errorf("multi analyzer must received two input column(text, analyzer_name)") } @@ -327,7 +344,11 @@ func (v 
*MultiAnalyzerBM25FunctionRunner) GetInputFields() []*schemapb.FieldSche } func (v *MultiAnalyzerBM25FunctionRunner) Close() { + v.mu.Lock() + defer v.mu.Unlock() + for _, analyzer := range v.analyzers { analyzer.Destroy() } + v.closed = true } diff --git a/internal/util/function/multi_analyzer_bm25_function_test.go b/internal/util/function/multi_analyzer_bm25_function_test.go index ef5cac28a5..819637d61a 100644 --- a/internal/util/function/multi_analyzer_bm25_function_test.go +++ b/internal/util/function/multi_analyzer_bm25_function_test.go @@ -155,6 +155,12 @@ func (s *MultiAnalyzerBM25FunctionSuite) TestBatchRun() { s.Equal(16, len(sparseArray.GetContents()[0])) // bytes size will be 3 * 2 * 4 = 24 s.Equal(24, len(sparseArray.GetContents()[1])) + + runner.Close() + + // run after close + _, err = runner.BatchRun(text, analyzerName) + s.Error(err) }) }