From 4202c775bad3bf5a04bc9533cdc240a4c0904429 Mon Sep 17 00:00:00 2001 From: junjiejiangjjj Date: Wed, 28 May 2025 19:18:28 +0800 Subject: [PATCH] feat: Support vllm and tei rerank (#41947) https://github.com/milvus-io/milvus/issues/35856 Signed-off-by: junjie.jiang --- configs/milvus.yaml | 7 + internal/proxy/task_search.go | 38 +- internal/util/function/common.go | 7 +- .../ali/ali_dashscope_text_embedding.go | 2 +- .../models/cohere/cohere_text_embedding.go | 2 +- .../models/openai/openai_embedding.go | 2 +- .../models/openai/openai_embedding_test.go | 4 +- .../siliconflow/siliconflow_text_embedding.go | 2 +- internal/util/function/models/tei/tei.go | 2 +- .../function/models/utils/embedding_util.go | 7 +- .../vertexai/vertexai_text_embedding.go | 2 +- .../voyageai/voyageai_text_embedding.go | 2 +- .../util/function/rerank/decay_function.go | 27 +- .../function/rerank/decay_function_test.go | 2 +- .../util/function/rerank/function_score.go | 5 +- .../util/function/rerank/model_function.go | 378 +++++++++++++ .../function/rerank/model_function_test.go | 509 ++++++++++++++++++ internal/util/function/rerank/rerank_base.go | 13 +- internal/util/function/rerank/util.go | 16 +- .../util/function/tei_embedding_provider.go | 4 +- .../function/tei_embedding_provider_test.go | 8 +- pkg/util/paramtable/function_param.go | 32 ++ .../test_milvus_client_search.py | 8 +- 23 files changed, 1013 insertions(+), 66 deletions(-) create mode 100644 internal/util/function/rerank/model_function.go create mode 100644 internal/util/function/rerank/model_function_test.go diff --git a/configs/milvus.yaml b/configs/milvus.yaml index b7db571627..5031bfa22f 100644 --- a/configs/milvus.yaml +++ b/configs/milvus.yaml @@ -1327,3 +1327,10 @@ function: voyageai: credential: # The name in the crendential configuration item url: # Your voyageai embedding url, Default is the official embedding url + rerank: + model: + providers: + tei: + enable: true # Whether to enable TEI rerank service + vllm: + 
enable: true # Whether to enable vllm rerank service diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 3c40701273..c4fc6589f9 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -646,6 +646,17 @@ func mergeIDs(idsList []*schemapb.IDs) (*schemapb.IDs, int) { func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, multipleMilvusResults []*milvuspb.SearchResults, searchMetrics []string) error { var err error + processRerank := func(ctx context.Context, results []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) { + ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf") + defer sp.End() + + params := rerank.NewSearchParams( + t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal, + t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics, + ) + return t.functionScore.Process(ctx, params, results) + } + // The first step of hybrid search is without meta information. If rerank requires meta data, we need to do requery. // At this time, outputFields and rerank input_fields will be recalled. 
// If we want to save memory, we can only recall the rerank input_fields in this step, and recall the output_fields in the third step @@ -682,12 +693,7 @@ func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, mult for i := 0; i < len(multipleMilvusResults); i++ { multipleMilvusResults[i].Results.FieldsData = fields[i] } - params := rerank.NewSearchParams( - t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal, - t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics, - ) - - if t.result, err = t.functionScore.Process(ctx, params, multipleMilvusResults); err != nil { + if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil { return err } if fields, err := t.reorganizeRequeryResults(ctx, queryResult.GetFieldsData(), []*schemapb.IDs{t.result.Results.Ids}); err != nil { @@ -696,11 +702,7 @@ func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, mult t.result.Results.FieldsData = fields[0] } } else { - params := rerank.NewSearchParams( - t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal, - t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics, - ) - if t.result, err = t.functionScore.Process(ctx, params, multipleMilvusResults); err != nil { + if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil { return err } } @@ -823,11 +825,15 @@ func (t *searchTask) searchPostProcess(ctx context.Context, span trace.Span, toR } if t.functionScore != nil && len(result.Results.FieldsData) != 0 { - params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), - t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, []string{metricType}) - // rank only returns id and score - if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil { - 
return err + { + ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf") + params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(), + t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, []string{metricType}) + // rank only returns id and score + t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}) + // End the span explicitly: a defer inside this bare block would only run when the enclosing function returns. + sp.End() + if err != nil { + return err + } + } if !t.needRequery { fields, err := t.reorganizeRequeryResults(ctx, result.Results.FieldsData, []*schemapb.IDs{t.result.Results.Ids}) diff --git a/internal/util/function/common.go b/internal/util/function/common.go index 62d5c243b4..c81940059d 100644 --- a/internal/util/function/common.go +++ b/internal/util/function/common.go @@ -109,16 +109,17 @@ const ( siliconflowAKEnvStr string = "MILVUSAI_SILICONFLOW_API_KEY" ) -// TEI +// TEI and vllm const ( ingestionPromptParamKey string = "ingestion_prompt" searchPromptParamKey string = "search_prompt" maxClientBatchSizeParamKey string = "max_client_batch_size" truncationDirectionParamKey string = "truncation_direction" - endpointParamKey string = "endpoint" + EndpointParamKey string = "endpoint" - enableTeiEnvStr string = "MILVUSAI_ENABLE_TEI" + EnableTeiEnvStr string = "MILVUSAI_ENABLE_TEI" + EnableVllmEnvStr string = "MILVUSAI_ENABLE_VLLM" ) func parseAKAndURL(credentials *credentials.Credentials, params []*commonpb.KeyValuePair, confParams map[string]string, apiKeyEnv string) (string, string, error) { diff --git a/internal/util/function/models/ali/ali_dashscope_text_embedding.go b/internal/util/function/models/ali/ali_dashscope_text_embedding.go index c55e2eb8ec..b82f4cbbd7 100644 --- a/internal/util/function/models/ali/ali_dashscope_text_embedding.go +++ b/internal/util/function/models/ali/ali_dashscope_text_embedding.go @@ -140,7 +140,7 @@ func (c *AliDashScopeEmbedding) Embedding(modelName string, 
texts []string, dim "Content-Type": "application/json", "Authorization": fmt.Sprintf("Bearer %s", c.apiKey), } - body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1) + body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3) if err != nil { return nil, err } diff --git a/internal/util/function/models/cohere/cohere_text_embedding.go b/internal/util/function/models/cohere/cohere_text_embedding.go index 03d2c528c1..c153047900 100644 --- a/internal/util/function/models/cohere/cohere_text_embedding.go +++ b/internal/util/function/models/cohere/cohere_text_embedding.go @@ -106,7 +106,7 @@ func (c *CohereEmbedding) Embedding(modelName string, texts []string, inputType "Content-Type": "application/json", "Authorization": fmt.Sprintf("bearer %s", c.apiKey), } - body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1) + body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3) if err != nil { return nil, err } diff --git a/internal/util/function/models/openai/openai_embedding.go b/internal/util/function/models/openai/openai_embedding.go index 8d65fef025..5fc8376c4d 100644 --- a/internal/util/function/models/openai/openai_embedding.go +++ b/internal/util/function/models/openai/openai_embedding.go @@ -149,7 +149,7 @@ func (c *openAIBase) embedding(url string, headers map[string]string, modelName ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second) defer cancel() - body, err := utils.RetrySend(ctx, data, http.MethodPost, url, headers, 3, 1) + body, err := utils.RetrySend(ctx, data, http.MethodPost, url, headers, 3) if err != nil { return nil, err } diff --git a/internal/util/function/models/openai/openai_embedding_test.go b/internal/util/function/models/openai/openai_embedding_test.go index cb6f03a3e7..3134043ed3 100644 --- a/internal/util/function/models/openai/openai_embedding_test.go +++ 
b/internal/util/function/models/openai/openai_embedding_test.go @@ -218,8 +218,8 @@ func TestEmbeddingFailed(t *testing.T) { func TestTimeout(t *testing.T) { var st int32 = 0 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // (Timeout 1s + Wait 1s) * Retry 3 - time.Sleep(6 * time.Second) + // (Timeout 1s + 2s + 4s + Wait 1s * 3) + time.Sleep(11 * time.Second) atomic.AddInt32(&st, 1) w.WriteHeader(http.StatusUnauthorized) })) diff --git a/internal/util/function/models/siliconflow/siliconflow_text_embedding.go b/internal/util/function/models/siliconflow/siliconflow_text_embedding.go index 1cce1e4545..d7189f2aed 100644 --- a/internal/util/function/models/siliconflow/siliconflow_text_embedding.go +++ b/internal/util/function/models/siliconflow/siliconflow_text_embedding.go @@ -121,7 +121,7 @@ func (c *SiliconflowEmbedding) Embedding(modelName string, texts []string, encod "Content-Type": "application/json", "Authorization": fmt.Sprintf("Bearer %s", c.apiKey), } - body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1) + body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3) if err != nil { return nil, err } diff --git a/internal/util/function/models/tei/tei.go b/internal/util/function/models/tei/tei.go index 31386ad76e..f2d69f88cb 100644 --- a/internal/util/function/models/tei/tei.go +++ b/internal/util/function/models/tei/tei.go @@ -93,7 +93,7 @@ func (c *TEIEmbedding) Embedding(texts []string, truncate bool, truncationDirect if c.apiKey != "" { headers["Authorization"] = fmt.Sprintf("Bearer %s", c.apiKey) } - body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1) + body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3) if err != nil { return nil, err } diff --git a/internal/util/function/models/utils/embedding_util.go b/internal/util/function/models/utils/embedding_util.go index d786842727..c0c7759c5b 100644 --- 
a/internal/util/function/models/utils/embedding_util.go +++ b/internal/util/function/models/utils/embedding_util.go @@ -21,6 +21,7 @@ import ( "context" "fmt" "io" + "math/rand" "net/http" "time" ) @@ -45,7 +46,7 @@ func send(req *http.Request) ([]byte, error) { return body, nil } -func RetrySend(ctx context.Context, data []byte, httpMethod string, url string, headers map[string]string, maxRetries int, retryDelay int) ([]byte, error) { +func RetrySend(ctx context.Context, data []byte, httpMethod string, url string, headers map[string]string, maxRetries int) ([]byte, error) { var err error var body []byte for i := 0; i < maxRetries; i++ { @@ -60,7 +61,9 @@ func RetrySend(ctx context.Context, data []byte, httpMethod string, url string, if err == nil { return body, nil } - time.Sleep(time.Duration(retryDelay) * time.Second) + backoffDelay := 1 << uint(i) * time.Second + jitter := time.Duration(rand.Int63n(int64(backoffDelay / 4))) + time.Sleep(backoffDelay + jitter) } return nil, err } diff --git a/internal/util/function/models/vertexai/vertexai_text_embedding.go b/internal/util/function/models/vertexai/vertexai_text_embedding.go index e2f5823f8d..893dc96b87 100644 --- a/internal/util/function/models/vertexai/vertexai_text_embedding.go +++ b/internal/util/function/models/vertexai/vertexai_text_embedding.go @@ -148,7 +148,7 @@ func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int6 "Content-Type": "application/json", "Authorization": fmt.Sprintf("Bearer %s", token), } - body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1) + body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3) if err != nil { return nil, err } diff --git a/internal/util/function/models/voyageai/voyageai_text_embedding.go b/internal/util/function/models/voyageai/voyageai_text_embedding.go index ca8ce74a6b..0a650e32b7 100644 --- a/internal/util/function/models/voyageai/voyageai_text_embedding.go +++ 
b/internal/util/function/models/voyageai/voyageai_text_embedding.go @@ -138,7 +138,7 @@ func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int, "Content-Type": "application/json", "Authorization": fmt.Sprintf("Bearer %s", c.apiKey), } - body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1) + body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3) if err != nil { return nil, err } diff --git a/internal/util/function/rerank/decay_function.go b/internal/util/function/rerank/decay_function.go index df648db729..7ae9db1a72 100644 --- a/internal/util/function/rerank/decay_function.go +++ b/internal/util/function/rerank/decay_function.go @@ -55,34 +55,17 @@ type DecayFunction[T PKType, R int32 | int64 | float32 | float64] struct { } func newDecayFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) { - pkType := schemapb.DataType_None - for _, field := range collSchema.Fields { - if field.IsPrimaryKey { - pkType = field.DataType - } - } - - if pkType == schemapb.DataType_None { - return nil, fmt.Errorf("Collection %s can not found pk field", collSchema.Name) - } - - base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false, pkType) + base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false) if err != nil { return nil, err } if len(base.GetInputFieldNames()) != 1 { - return nil, fmt.Errorf("Decay function only supoorts single input, but gets [%s] input", base.GetInputFieldNames()) + return nil, fmt.Errorf("Decay function only supports single input, but gets [%s] input", base.GetInputFieldNames()) } - var inputType schemapb.DataType - for _, field := range collSchema.Fields { - if field.Name == base.GetInputFieldNames()[0] { - inputType = field.DataType - } - } - - if pkType == schemapb.DataType_Int64 { + inputType := base.GetInputFieldTypes()[0] + if base.pkType == schemapb.DataType_Int64 { switch inputType { case 
schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32: return newFunction[int64, int32](base, funcSchema) @@ -160,7 +143,7 @@ func newFunction[T PKType, R int32 | int64 | float32 | float64](base *RerankBase } if decayFunc.decay <= 0 || decayFunc.decay >= 1 { - return nil, fmt.Errorf("Decay function param: decay must 0 < decay < 1, but got %f", decayFunc.offset) + return nil, fmt.Errorf("Decay function param: decay must 0 < decay < 1, but got %f", decayFunc.decay) } switch decayFunc.functionName { diff --git a/internal/util/function/rerank/decay_function_test.go b/internal/util/function/rerank/decay_function_test.go index 6b4214015c..deae1fdd02 100644 --- a/internal/util/function/rerank/decay_function_test.go +++ b/internal/util/function/rerank/decay_function_test.go @@ -90,7 +90,7 @@ func (s *DecayFunctionSuite) TestNewDecayErrors() { { functionSchema.InputFieldNames = []string{"ts", "pk"} _, err := newDecayFunction(schema, functionSchema) - s.ErrorContains(err, "Decay function only supoorts single input, but gets") + s.ErrorContains(err, "Decay function only supports single input, but gets") } { diff --git a/internal/util/function/rerank/function_score.go b/internal/util/function/rerank/function_score.go index 64aaf904dd..ea15fdd8c3 100644 --- a/internal/util/function/rerank/function_score.go +++ b/internal/util/function/rerank/function_score.go @@ -32,6 +32,7 @@ import ( const ( decayFunctionName string = "decay" + modelFunctionName string = "model" ) type SearchParams struct { @@ -92,8 +93,10 @@ func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb. 
switch rerankerName { case decayFunctionName: rerankFunc, newRerankErr = newDecayFunction(collSchema, funcSchema) + case modelFunctionName: + rerankFunc, newRerankErr = newModelFunction(collSchema, funcSchema) default: - return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s]", rerankerName, decayFunctionName) + return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s,%s]", rerankerName, decayFunctionName, modelFunctionName) } if newRerankErr != nil { diff --git a/internal/util/function/rerank/model_function.go b/internal/util/function/rerank/model_function.go new file mode 100644 index 0000000000..79e0205196 --- /dev/null +++ b/internal/util/function/rerank/model_function.go @@ -0,0 +1,378 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
+ */ + +package rerank + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "os" + "sort" + "strconv" + "strings" + "time" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/util/function" + "github.com/milvus-io/milvus/internal/util/function/models/utils" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" +) + +const ( + providerParamName string = "provider" + vllmProviderName string = "vllm" + teiProviderName string = "tei" + + queryKeyName string = "queries" + maxBatchKeyName string = "max_batch" +) + +type modelProvider interface { + rerank(context.Context, string, []string) ([]float32, error) + getURL() string +} + +type baseModel struct { + url string + maxBatch int + + queryKey string + docKey string + + parseScores func([]byte) ([]float32, error) +} + +func (base *baseModel) getURL() string { + return base.url +} + +// rerank scores docs against query by sending one request per batch of at most maxBatch docs and concatenating the scores in request order. +// The first failing batch aborts the call; its error is wrapped with %w so callers can still inspect the cause. +func (base *baseModel) rerank(ctx context.Context, query string, docs []string) ([]float32, error) { + requestBodies, err := genRerankRequestBody(query, docs, base.maxBatch, base.queryKey, base.docKey) + if err != nil { + return nil, err + } + scores := []float32{} + for _, requestBody := range requestBodies { + rerankResp, err := base.callService(ctx, requestBody, 30) + if err != nil { + return nil, fmt.Errorf("Call rerank model failed: %w", err) + } + scores = append(scores, rerankResp...) 
+ } + + if len(scores) != len(docs) { + return nil, fmt.Errorf("Call Rerank service failed, %d docs but got %d scores", len(docs), len(scores)) + } + return scores, nil +} + +func (base *baseModel) callService(ctx context.Context, requestBody []byte, timeoutSec int64) ([]float32, error) { + ctx, cancel := context.WithTimeout(ctx, time.Duration(timeoutSec)*time.Second) + defer cancel() + headers := map[string]string{ + "Content-Type": "application/json", + } + body, err := utils.RetrySend(ctx, requestBody, http.MethodPost, base.url, headers, 3) + if err != nil { + return nil, err + } + return base.parseScores(body) +} + +type vllmRerankRequest struct { + Query string `json:"query"` + Documents []string `json:"documents"` +} + +type vllmRerankResponse struct { + ID string `json:"id"` + Model string `json:"model"` + Usage vllmUsage `json:"usage"` + Results []vllmResult `json:"results"` +} + +type vllmUsage struct { + TotalTokens int `json:"total_tokens"` +} + +type vllmResult struct { + Index int `json:"index"` + Document vllmDocument `json:"document"` + RelevanceScore float32 `json:"relevance_score"` +} + +type vllmDocument struct { + Text string `json:"text"` +} + +type vllmProvider struct { + baseModel +} + +func newVllmProvider(params []*commonpb.KeyValuePair, conf map[string]string) (modelProvider, error) { + if !isEnable(conf, function.EnableVllmEnvStr) { + return nil, fmt.Errorf("Vllm rerank is disabled") + } + endpoint, maxBatch, err := parseParams(params) + if err != nil { + return nil, err + } + base, _ := url.Parse(endpoint) + base.Path = "/v2/rerank" + model := baseModel{ + url: base.String(), + maxBatch: maxBatch, + queryKey: "query", + docKey: "documents", + parseScores: func(body []byte) ([]float32, error) { + var rerankResp vllmRerankResponse + if err := json.Unmarshal(body, &rerankResp); err != nil { + return nil, fmt.Errorf("Rerank error, parsing vllm response failed: %v", err) + } + + sort.Slice(rerankResp.Results, func(i, j int) bool { + return 
rerankResp.Results[i].Index < rerankResp.Results[j].Index + }) + + scores := make([]float32, 0, len(rerankResp.Results)) + for _, result := range rerankResp.Results { + scores = append(scores, result.RelevanceScore) + } + + return scores, nil + }, + } + return &vllmProvider{baseModel: model}, nil +} + +type teiProvider struct { + baseModel +} + +type TEIResponse struct { + Index int `json:"index"` + Score float32 `json:"score"` +} + +func newTeiProvider(params []*commonpb.KeyValuePair, conf map[string]string) (modelProvider, error) { + if !isEnable(conf, function.EnableTeiEnvStr) { + return nil, fmt.Errorf("TEI rerank is disabled") + } + endpoint, maxBatch, err := parseParams(params) + if err != nil { + return nil, err + } + base, _ := url.Parse(endpoint) + base.Path = "/rerank" + model := baseModel{ + url: base.String(), + maxBatch: maxBatch, + queryKey: "query", + docKey: "texts", + parseScores: func(body []byte) ([]float32, error) { + var results []TEIResponse + if err := json.Unmarshal(body, &results); err != nil { + return nil, fmt.Errorf("Rerank error, parsing TEI response failed: %v", err) + } + sort.Slice(results, func(i, j int) bool { + return results[i].Index < results[j].Index + }) + scores := make([]float32, len(results)) + for i, result := range results { + scores[i] = result.Score + } + return scores, nil + }, + } + return &teiProvider{baseModel: model}, nil +} + +func isEnable(conf map[string]string, envKey string) bool { + // milvus.yaml > env + value, exists := conf["enable"] + if exists { + return strings.ToLower(value) == "true" + } else { + return !(strings.ToLower(os.Getenv(envKey)) == "false") + } +} + +func parseParams(params []*commonpb.KeyValuePair) (string, int, error) { + endpoint := "" + maxBatch := 32 + for _, param := range params { + switch strings.ToLower(param.Key) { + case function.EndpointParamKey: + base, err := url.Parse(param.Value) + if err != nil { + return "", 0, err + } + if base.Scheme != "http" && base.Scheme != "https" { 
+ return "", 0, fmt.Errorf("Rerank endpoint: [%s] is not a valid http/https link", param.Value) + } + if base.Host == "" { + return "", 0, fmt.Errorf("Rerank endpoint: [%s] is not a valid http/https link", param.Value) + } + endpoint = base.String() + case maxBatchKeyName: + if batch, err := strconv.ParseInt(param.Value, 10, 64); err != nil { + return "", 0, fmt.Errorf("Rerank params error, maxBatch: %s is not a number", param.Value) + } else { + maxBatch = int(batch) + } + } + } + if endpoint == "" { + return "", 0, fmt.Errorf("Rerank function lost params endpoint") + } + if maxBatch <= 0 { + return "", 0, fmt.Errorf("Rerank function params max_batch must > 0, but got %d", maxBatch) + } + return endpoint, maxBatch, nil +} + +func genRerankRequestBody(query string, documents []string, maxSize int, queryKey string, docKey string) ([][]byte, error) { + requestBodies := [][]byte{} + for i := 0; i < len(documents); i += maxSize { + end := i + maxSize + if end > len(documents) { + end = len(documents) + } + requestBody := map[string]interface{}{ + queryKey: query, + docKey: documents[i:end], + } + jsonData, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("Create model rerank request failed, err: %s", err) + } + requestBodies = append(requestBodies, jsonData) + } + return requestBodies, nil +} + +func newProvider(params []*commonpb.KeyValuePair) (modelProvider, error) { + for _, param := range params { + if strings.ToLower(param.Key) == providerParamName { + provider := strings.ToLower(param.Value) + conf := paramtable.Get().FunctionCfg.GetRerankModelProviders(provider) + switch provider { + case vllmProviderName: + return newVllmProvider(params, conf) + case teiProviderName: + return newTeiProvider(params, conf) + default: + return nil, fmt.Errorf("Unknow rerank provider:%s", param.Value) + } + } + } + return nil, fmt.Errorf("Lost rerank params:%s ", providerParamName) +} + +type ModelFunction[T PKType] struct { + RerankBase + + provider 
modelProvider + queries []string +} + +// newModelFunction builds the model-based reranker for a function schema. +// It requires exactly one VarChar input field, a provider resolved by newProvider and a non-empty JSON "queries" list, and returns a ModelFunction instantiated for int64 primary keys or, otherwise, string primary keys. +func newModelFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) { + // Register the base under the model reranker name, not the decay one. + base, err := newRerankBase(collSchema, funcSchema, modelFunctionName, false) + if err != nil { + return nil, err + } + + if len(base.GetInputFieldNames()) != 1 { + return nil, fmt.Errorf("Rerank model only supports single input, but gets [%s] input", base.GetInputFieldNames()) + } + + if base.GetInputFieldTypes()[0] != schemapb.DataType_VarChar { + return nil, fmt.Errorf("Rerank model only support varchar, bug got [%s]", base.GetInputFieldTypes()[0].String()) + } + + provider, err := newProvider(funcSchema.Params) + if err != nil { + return nil, err + } + + queries := []string{} + for _, param := range funcSchema.Params { + if param.Key == queryKeyName { + if err := json.Unmarshal([]byte(param.Value), &queries); err != nil { + return nil, fmt.Errorf("Parse rerank params [queries] failed, err: %v", err) + } + } + } + if len(queries) == 0 { + return nil, fmt.Errorf("Rerank function lost params queries") + } + + if base.pkType == schemapb.DataType_Int64 { + return &ModelFunction[int64]{RerankBase: *base, provider: provider, queries: queries}, nil + } else { + return &ModelFunction[string]{RerankBase: *base, provider: provider, queries: queries}, nil + } +} + +func (model *ModelFunction[T]) processOneSearchData(ctx context.Context, searchParams *SearchParams, query string, cols []*columns) (*IDScores[T], error) { + uniqueData := make(map[T]string) + for _, col := range cols { + texts := col.data[0].([]string) + ids := col.ids.([]T) + for idx, id := range ids { + if _, ok := uniqueData[id]; !ok { + uniqueData[id] = texts[idx] + } + } + } + ids := make([]T, 0, len(uniqueData)) + texts := make([]string, 0, len(uniqueData)) + for id, text := range uniqueData { + ids = append(ids, id) + texts = append(texts, text) + } + scores, err := model.provider.rerank(ctx, query, texts) + if err != nil { + return nil, err + } + + 
rerankScores := map[T]float32{} + for idx, id := range ids { + rerankScores[id] = scores[idx] + } + return newIDScores(rerankScores, searchParams), nil +} + +func (model *ModelFunction[T]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) { + if len(model.queries) != int(searchParams.nq) { + return nil, fmt.Errorf("nq must equal to queries size, but got nq [%d], queries size [%d], queries: [%v]", searchParams.nq, len(model.queries), model.queries) + } + outputs := newRerankOutputs(searchParams) + for idx, cols := range inputs.data { + idScore, err := model.processOneSearchData(ctx, searchParams, model.queries[idx], cols) + if err != nil { + return nil, err + } + appendResult(outputs, idScore.ids, idScore.scores) + } + return outputs, nil +} diff --git a/internal/util/function/rerank/model_function_test.go b/internal/util/function/rerank/model_function_test.go new file mode 100644 index 0000000000..4e3f709f8e --- /dev/null +++ b/internal/util/function/rerank/model_function_test.go @@ -0,0 +1,509 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. 
+ */ + +package rerank + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "os" + "testing" + + "github.com/stretchr/testify/suite" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" + "github.com/milvus-io/milvus/internal/util/function" + "github.com/milvus-io/milvus/pkg/v2/util/paramtable" +) + +func TestRerankModel(t *testing.T) { + suite.Run(t, new(RerankModelSuite)) +} + +type RerankModelSuite struct { + suite.Suite +} + +func (s *RerankModelSuite) SetupTest() { + paramtable.Init() + paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = func() map[string]string { + return map[string]string{} + } +} + +func (s *RerankModelSuite) TestNewProvider() { + { + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "unknown"}, + } + _, err := newProvider(params) + s.ErrorContains(err, "Unknow rerank provider") + } + + { + _, err := newProvider([]*commonpb.KeyValuePair{}) + s.ErrorContains(err, "Lost rerank params") + } + + { + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "vllm"}, + {Key: function.EndpointParamKey, Value: "illegal"}, + } + _, err := newProvider(params) + s.ErrorContains(err, "is not a valid http/https link") + } + { + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "vllm"}, + {Key: function.EndpointParamKey, Value: "http://"}, + } + _, err := newProvider(params) + s.ErrorContains(err, "is not a valid http/https link") + } + { + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "vllm"}, + } + _, err := newProvider(params) + s.ErrorContains(err, "Rerank function lost params endpoint") + } + { + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "vllm"}, + {Key: function.EndpointParamKey, Value: "http://localhost:80"}, + {Key: maxBatchKeyName, Value: "-1"}, + } + _, err := newProvider(params) + s.ErrorContains(err, "Rerank function params max_batch must 
> 0") + } + { + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "vllm"}, + {Key: function.EndpointParamKey, Value: "http://localhost:80"}, + {Key: maxBatchKeyName, Value: "NotNum"}, + } + _, err := newProvider(params) + s.ErrorContains(err, "Rerank params error, maxBatch") + } + { + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "vllm"}, + {Key: function.EndpointParamKey, Value: "http://localhost:80"}, + } + provder, err := newProvider(params) + s.NoError(err) + s.Equal(provder.getURL(), "http://localhost:80/v2/rerank") + } + { + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "tei"}, + {Key: function.EndpointParamKey, Value: "http://localhost:80"}, + } + provder, err := newProvider(params) + s.NoError(err) + s.Equal(provder.getURL(), "http://localhost:80/rerank") + } + { + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "tei"}, + {Key: function.EndpointParamKey, Value: "http://localhost:80"}, + } + paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = func() map[string]string { + key := "tei.enable" + return map[string]string{ + key: "false", + } + } + _, err := newProvider(params) + s.ErrorContains(err, "TEI rerank is disabled") + paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = func() map[string]string { + return map[string]string{} + } + os.Setenv(function.EnableTeiEnvStr, "false") + _, err = newProvider(params) + s.ErrorContains(err, "TEI rerank is disabled") + os.Unsetenv(function.EnableTeiEnvStr) + } + { + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "vllm"}, + {Key: function.EndpointParamKey, Value: "http://localhost:80"}, + } + paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = func() map[string]string { + key := "vllm.enable" + return map[string]string{ + key: "false", + } + } + _, err := newProvider(params) + s.ErrorContains(err, "Vllm rerank is disabled") + paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = 
func() map[string]string { + return map[string]string{} + } + os.Setenv(function.EnableVllmEnvStr, "false") + _, err = newProvider(params) + s.ErrorContains(err, "Vllm rerank is disabled") + os.Unsetenv(function.EnableVllmEnvStr) + } +} + +func (s *RerankModelSuite) TestCallVllm() { + { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"results": [{"index": 0, "relevance_score": 0.1}, {"index": 1, "relevance_score": 0.2}]}`)) + })) + defer ts.Close() + + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "vllm"}, + {Key: function.EndpointParamKey, Value: ts.URL}, + } + provder, err := newProvider(params) + s.NoError(err) + _, err = provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"}) + s.ErrorContains(err, "Call Rerank service failed, 3 docs but got 2 scores") + } + { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{'object': 'error', 'message': 'The model vllm-test does not exist.', 'type': 'NotFoundError', 'param': None, 'code': 404}`)) + })) + defer ts.Close() + + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "vllm"}, + {Key: function.EndpointParamKey, Value: ts.URL}, + } + provder, err := newProvider(params) + s.NoError(err) + _, err = provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"}) + s.ErrorContains(err, "Call rerank model failed") + } + { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`Not json data`)) + })) + defer ts.Close() + + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "vllm"}, + {Key: function.EndpointParamKey, Value: ts.URL}, + } + provder, err := newProvider(params) + s.NoError(err) + _, err = provder.rerank(context.Background(), "mytest", []string{"t1", 
"t2", "t3"}) + s.ErrorContains(err, "Rerank error, parsing vllm response failed") + } + { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"results": [{"index": 0, "relevance_score": 0.0}, {"index": 2, "relevance_score": 0.2}, {"index": 1, "relevance_score": 0.1}]}`)) + })) + defer ts.Close() + + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "vllm"}, + {Key: function.EndpointParamKey, Value: ts.URL}, + } + provder, err := newProvider(params) + s.NoError(err) + scores, err := provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"}) + s.NoError(err) + s.Equal([]float32{0.0, 0.1, 0.2}, scores) + } +} + +func (s *RerankModelSuite) TestCallTEI() { + { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`[{"index":0,"score":0.0},{"index":1,"score":0.2}, {"index":2,"score":0.1}]`)) + })) + defer ts.Close() + + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "tei"}, + {Key: function.EndpointParamKey, Value: ts.URL}, + } + provder, err := newProvider(params) + s.NoError(err) + scores, err := provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"}) + s.NoError(err) + s.Equal([]float32{0.0, 0.2, 0.1}, scores) + } + { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`not json`)) + })) + defer ts.Close() + + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "tei"}, + {Key: function.EndpointParamKey, Value: ts.URL}, + } + provder, err := newProvider(params) + s.NoError(err) + _, err = provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"}) + s.ErrorContains(err, "Rerank error, parsing TEI response failed") + } +} + +func (s *RerankModelSuite) TestNewModelFunction() { + schema := 
&schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + {FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64}, + }, + } + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Rerank, + InputFieldNames: []string{"text"}, + Params: []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "tei"}, + {Key: function.EndpointParamKey, Value: "http://localhost:80"}, + {Key: queryKeyName, Value: `["q1"]`}, + }, + } + { + functionSchema.InputFieldNames = []string{"text", "ts"} + _, err := newModelFunction(schema, functionSchema) + s.ErrorContains(err, "Rerank model only supports single input") + } + { + functionSchema.InputFieldNames = []string{"ts"} + _, err := newModelFunction(schema, functionSchema) + s.ErrorContains(err, "Rerank model only support varchar") + functionSchema.InputFieldNames = []string{"text"} + } + { + functionSchema.Params[2] = &commonpb.KeyValuePair{Key: queryKeyName, Value: `NotJson`} + _, err := newModelFunction(schema, functionSchema) + s.ErrorContains(err, "Parse rerank params [queries] failed") + } + { + functionSchema.Params[2] = &commonpb.KeyValuePair{Key: queryKeyName, Value: `[]`} + _, err := newModelFunction(schema, functionSchema) + s.ErrorContains(err, "Rerank function lost params queries") + } + { + functionSchema.Params[2] = &commonpb.KeyValuePair{Key: queryKeyName, Value: `["test"]`} + _, err := newModelFunction(schema, functionSchema) + s.NoError(err) + } + { + schema.Fields[0] = &schemapb.FieldSchema{FieldID: 100, Name: "pk", DataType: schemapb.DataType_VarChar, IsPrimaryKey: true} + _, err := newModelFunction(schema, functionSchema) + s.NoError(err) + } + { + 
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: providerParamName, Value: `notExist`} + _, err := newModelFunction(schema, functionSchema) + s.ErrorContains(err, "Unknow rerank provider") + } + { + functionSchema.Params[0] = &commonpb.KeyValuePair{Key: "NotExist", Value: `notExist`} + _, err := newModelFunction(schema, functionSchema) + s.ErrorContains(err, "Lost rerank params") + } +} + +func (s *RerankModelSuite) TestRerankProcess() { + schema := &schemapb.CollectionSchema{ + Name: "test", + Fields: []*schemapb.FieldSchema{ + {FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true}, + {FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar}, + { + FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dim", Value: "4"}, + }, + }, + }, + } + { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"results": [{"index": 0, "relevance_score": 0.1}, {"index": 1, "relevance_score": 0.2}]}`)) + })) + defer ts.Close() + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Rerank, + InputFieldNames: []string{"text"}, + Params: []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "tei"}, + {Key: function.EndpointParamKey, Value: ts.URL}, + {Key: queryKeyName, Value: `["q1"]`}, + }, + } + + // empty + { + nq := int64(1) + f, err := newModelFunction(schema, functionSchema) + s.NoError(err) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs()) + ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs) + s.NoError(err) + s.Equal(int64(3), ret.searchResultData.TopK) + s.Equal([]int64{}, ret.searchResultData.Topks) + } + + // no input field exist + { + nq := int64(1) + f, err := newModelFunction(schema, functionSchema) + data := genSearchResultData(nq, 10, 
schemapb.DataType_Int64, "noExist", 1000) + s.NoError(err) + + _, err = newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs()) + s.ErrorContains(err, "Search reaults mismatch rerank inputs") + } + } + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + req := map[string]any{} + body, _ := io.ReadAll(r.Body) + defer r.Body.Close() + json.Unmarshal(body, &req) + ret := map[string][]map[string]any{} + ret["results"] = []map[string]any{} + for i := range req["documents"].([]any) { + d := map[string]any{} + d["index"] = i + d["relevance_score"] = float32(i) / 10 + ret["results"] = append(ret["results"], d) + } + jsonData, _ := json.Marshal(ret) + w.Write(jsonData) + })) + defer ts.Close() + + // singleSearchResultData + functionSchema := &schemapb.FunctionSchema{ + Name: "test", + Type: schemapb.FunctionType_Rerank, + InputFieldNames: []string{"text"}, + Params: []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "vllm"}, + {Key: function.EndpointParamKey, Value: ts.URL}, + {Key: queryKeyName, Value: `["q1", "q2"]`}, + }, + } + + { + nq := int64(1) + { + f, err := newModelFunction(schema, functionSchema) + s.NoError(err) + data := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs()) + _, err = f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs) + s.ErrorContains(err, "nq must equal to queries size, but got nq [1], queries size [2]") + } + { + functionSchema.Params[2].Value = `["q1"]` + f, err := newModelFunction(schema, functionSchema) + s.NoError(err) + data := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs()) + ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 0, -1, -1, 1, false, 
[]string{"COSINE"}}, inputs) + s.NoError(err) + s.Equal([]int64{3}, ret.searchResultData.Topks) + s.Equal(int64(3), ret.searchResultData.TopK) + } + + { + nq := int64(3) + functionSchema.Params[2].Value = `["q1", "q2", "q3"]` + f, err := newModelFunction(schema, functionSchema) + s.NoError(err) + data := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs()) + ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE", "COSINE"}}, inputs) + s.NoError(err) + s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks) + s.Equal(int64(3), ret.searchResultData.TopK) + } + } + + // // multipSearchResultData + // has empty inputs + { + nq := int64(1) + functionSchema.Params[2].Value = `["q1"]` + f, err := newModelFunction(schema, functionSchema) + s.NoError(err) + data1 := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101) + // empty + data2 := genSearchResultData(nq, 0, schemapb.DataType_VarChar, "text", 101) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs()) + ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs) + s.NoError(err) + s.Equal([]int64{3}, ret.searchResultData.Topks) + s.Equal(int64(3), ret.searchResultData.TopK) + } + // nq = 1 + { + nq := int64(1) + f, err := newModelFunction(schema, functionSchema) + s.NoError(err) + // ts/id data: 0 - 9 + data1 := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101) + // ts/id data: 0 - 3 + data2 := genSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs()) + ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, inputs) + s.NoError(err) + s.Equal([]int64{3}, 
ret.searchResultData.Topks) + s.Equal(int64(3), ret.searchResultData.TopK) + } + // // nq = 3 + { + nq := int64(3) + functionSchema.Params[2].Value = `["q1", "q2", "q3"]` + f, err := newModelFunction(schema, functionSchema) + s.NoError(err) + data1 := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101) + data2 := genSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101) + inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs()) + ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, 1, -1, 1, false, []string{"COSINE", "COSINE", "COSINE"}}, inputs) + s.NoError(err) + s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks) + s.Equal(int64(3), ret.searchResultData.TopK) + } +} diff --git a/internal/util/function/rerank/rerank_base.go b/internal/util/function/rerank/rerank_base.go index 8b8254ca7e..ee6f0d6495 100644 --- a/internal/util/function/rerank/rerank_base.go +++ b/internal/util/function/rerank/rerank_base.go @@ -46,12 +46,18 @@ type RerankBase struct { pkType schemapb.DataType inputFieldIDs []int64 inputFieldNames []string + inputFieldTypes []schemapb.DataType // TODO: The parameter is passed to the reranker, and the reranker decides whether to implement the parameter searchParams *searchParams } -func newRerankBase(coll *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema, rerankerName string, isSupportGroup bool, pkType schemapb.DataType) (*RerankBase, error) { +func newRerankBase(coll *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema, rerankerName string, isSupportGroup bool) (*RerankBase, error) { + pkType, err := getPKType(coll) + if err != nil { + return nil, err + } + base := RerankBase{ inputFieldNames: funcSchema.InputFieldNames, rerankerName: rerankerName, @@ -82,6 +88,7 @@ func newRerankBase(coll *schemapb.CollectionSchema, funcSchema *schemapb.Functio return nil, fmt.Errorf("Function input field cannot be nullable: field %s", inputField.GetName()) } 
base.inputFieldIDs = append(base.inputFieldIDs, inputField.FieldID) + base.inputFieldTypes = append(base.inputFieldTypes, inputField.DataType) } return &base, nil } @@ -90,6 +97,10 @@ func (base *RerankBase) GetInputFieldNames() []string { return base.inputFieldNames } +func (base *RerankBase) GetInputFieldTypes() []schemapb.DataType { + return base.inputFieldTypes +} + func (base *RerankBase) GetInputFieldIDs() []int64 { return base.inputFieldIDs } diff --git a/internal/util/function/rerank/util.go b/internal/util/function/rerank/util.go index 0ec918dfed..6fe8b1d6a1 100644 --- a/internal/util/function/rerank/util.go +++ b/internal/util/function/rerank/util.go @@ -236,7 +236,7 @@ func getField(inputField *schemapb.FieldData, start int64, size int64) (any, err return inputField.GetScalars().GetBoolData().Data[start : start+size], nil } return []bool{}, nil - case schemapb.DataType_String: + case schemapb.DataType_String, schemapb.DataType_VarChar: if inputField.GetScalars() != nil && inputField.GetScalars().GetStringData() != nil { return inputField.GetScalars().GetStringData().Data[start : start+size], nil } @@ -285,3 +285,17 @@ func maxMerge[T PKType](cols []*columns) map[T]float32 { } return srcScores } + +func getPKType(collSchema *schemapb.CollectionSchema) (schemapb.DataType, error) { + pkType := schemapb.DataType_None + for _, field := range collSchema.Fields { + if field.IsPrimaryKey { + pkType = field.DataType + } + } + + if pkType == schemapb.DataType_None { + return pkType, fmt.Errorf("Collection %s can not found pk field", collSchema.Name) + } + return pkType, nil +} diff --git a/internal/util/function/tei_embedding_provider.go b/internal/util/function/tei_embedding_provider.go index d74d4eeaa7..b5703419d0 100644 --- a/internal/util/function/tei_embedding_provider.go +++ b/internal/util/function/tei_embedding_provider.go @@ -47,7 +47,7 @@ type TeiEmbeddingProvider struct { } func createTEIEmbeddingClient(apiKey string, endpoint string) 
(*tei.TEIEmbedding, error) { - enable := os.Getenv(enableTeiEnvStr) + enable := os.Getenv(EnableTeiEnvStr) if strings.ToLower(enable) == "false" { return nil, errors.New("TEI model serving is not enabled") } @@ -69,7 +69,7 @@ func NewTEIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema * for _, param := range functionSchema.Params { switch strings.ToLower(param.Key) { - case endpointParamKey: + case EndpointParamKey: endpoint = param.Value case ingestionPromptParamKey: ingestionPrompt = param.Value diff --git a/internal/util/function/tei_embedding_provider_test.go b/internal/util/function/tei_embedding_provider_test.go index e3f7efa44c..e93cfa406a 100644 --- a/internal/util/function/tei_embedding_provider_test.go +++ b/internal/util/function/tei_embedding_provider_test.go @@ -70,7 +70,7 @@ func createTEIProvider(url string, schema *schemapb.FieldSchema, providerName st OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: credentialParamKey, Value: "mock"}, - {Key: endpointParamKey, Value: url}, + {Key: EndpointParamKey, Value: url}, {Key: ingestionPromptParamKey, Value: "doc:"}, {Key: searchPromptParamKey, Value: "query:"}, }, @@ -154,8 +154,8 @@ func (s *TEITextEmbeddingProviderSuite) TestCreateTEIEmbeddingClient() { _, err = createTEIEmbeddingClient("", "http://mymock.com") s.NoError(err) - os.Setenv(enableTeiEnvStr, "false") - defer os.Unsetenv(enableTeiEnvStr) + os.Setenv(EnableTeiEnvStr, "false") + defer os.Unsetenv(EnableTeiEnvStr) _, err = createTEIEmbeddingClient("", "http://mymock.com") s.Error(err) } @@ -170,7 +170,7 @@ func (s *TEITextEmbeddingProviderSuite) TestNewTEIEmbeddingProvider() { OutputFieldIds: []int64{102}, Params: []*commonpb.KeyValuePair{ {Key: credentialParamKey, Value: "mock"}, - {Key: endpointParamKey, Value: "http://mymock.com"}, + {Key: EndpointParamKey, Value: "http://mymock.com"}, }, } provider, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}, 
credentials.NewCredentials(map[string]string{"mock.apikey": "mock"})) diff --git a/pkg/util/paramtable/function_param.go b/pkg/util/paramtable/function_param.go index dad1f29662..7f381d4648 100644 --- a/pkg/util/paramtable/function_param.go +++ b/pkg/util/paramtable/function_param.go @@ -22,6 +22,7 @@ import ( type functionConfig struct { TextEmbeddingProviders ParamGroup `refreshable:"true"` + RerankModelProviders ParamGroup `refreshable:"true"` } func (p *functionConfig) init(base *BaseTable) { @@ -73,6 +74,23 @@ func (p *functionConfig) init(base *BaseTable) { }, } p.TextEmbeddingProviders.Init(base.mgr) + + p.RerankModelProviders = ParamGroup{ + KeyPrefix: "function.rerank.model.providers.", + Version: "2.6.0", + Export: true, + DocFunc: func(key string) string { + switch key { + case "tei.enable": + return "Whether to enable TEI rerank service" + case "vllm.enable": + return "Whether to enable vllm rerank service" + default: + return "" + } + }, + } + p.RerankModelProviders.Init(base.mgr) } const ( @@ -92,3 +110,17 @@ func (p *functionConfig) GetTextEmbeddingProviderConfig(providerName string) map } return matchedParam } + +func (p *functionConfig) GetRerankModelProviders(providerName string) map[string]string { + matchedParam := make(map[string]string) + + params := p.RerankModelProviders.GetValue() + prefix := providerName + "." 
+ + for k, v := range params { + if strings.HasPrefix(k, prefix) { + matchedParam[strings.TrimPrefix(k, prefix)] = v + } + } + return matchedParam +} diff --git a/tests/python_client/milvus_client/test_milvus_client_search.py b/tests/python_client/milvus_client/test_milvus_client_search.py index 59aabb297a..02a07d9c0f 100644 --- a/tests/python_client/milvus_client/test_milvus_client_search.py +++ b/tests/python_client/milvus_client/test_milvus_client_search.py @@ -925,7 +925,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base): ) vectors_to_search = rng.random((1, dim)) error = {ct.err_code: 65535, - ct.err_msg: f"Decay function only supoorts single input, but gets [[reranker_field id]] input"} + ct.err_msg: f"Decay function only supports single input, but gets [[reranker_field id]] input"} self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn, check_task=CheckTasks.err_res, check_items=error) @@ -1053,7 +1053,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base): ) vectors_to_search = rng.random((1, dim)) error = {ct.err_code: 65535, - ct.err_msg: f"Unsupported rerank function: [1] , list of supported [decay]"} + ct.err_msg: f"Unsupported rerank function: [1]"} self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn, check_task=CheckTasks.err_res, check_items=error) @@ -1098,7 +1098,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base): ) vectors_to_search = rng.random((1, dim)) error = {ct.err_code: 65535, - ct.err_msg: f"Unsupported rerank function: [{not_supported_reranker}] , list of supported [decay]"} + ct.err_msg: f"Unsupported rerank function: [{not_supported_reranker}]"} self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn, check_task=CheckTasks.err_res, check_items=error) @@ -3997,4 +3997,4 @@ class TestMilvusClientSearchRerankValid(TestMilvusClientV2Base): "nq": len(vectors_to_search), "pk_name": default_primary_key_field_name, "limit": default_limit} - ) 
\ No newline at end of file + )