feat: Support vllm and tei rerank (#41947)

https://github.com/milvus-io/milvus/issues/35856

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
This commit is contained in:
junjiejiangjjj 2025-05-28 19:18:28 +08:00 committed by GitHub
parent 14563ad2b3
commit 4202c775ba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 1013 additions and 66 deletions

View File

@ -1327,3 +1327,10 @@ function:
voyageai:
credential: # The name in the credential configuration item
url: # Your voyageai embedding url, Default is the official embedding url
rerank:
model:
providers:
tei:
enable: true # Whether to enable TEI rerank service
vllm:
enable: true # Whether to enable vllm rerank service

View File

@ -646,6 +646,17 @@ func mergeIDs(idsList []*schemapb.IDs) (*schemapb.IDs, int) {
func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, multipleMilvusResults []*milvuspb.SearchResults, searchMetrics []string) error {
var err error
processRerank := func(ctx context.Context, results []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
defer sp.End()
params := rerank.NewSearchParams(
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics,
)
return t.functionScore.Process(ctx, params, results)
}
// The first step of hybrid search is without meta information. If rerank requires meta data, we need to do requery.
// At this time, outputFields and rerank input_fields will be recalled.
// If we want to save memory, we can only recall the rerank input_fields in this step, and recall the output_fields in the third step
@ -682,12 +693,7 @@ func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, mult
for i := 0; i < len(multipleMilvusResults); i++ {
multipleMilvusResults[i].Results.FieldsData = fields[i]
}
params := rerank.NewSearchParams(
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics,
)
if t.result, err = t.functionScore.Process(ctx, params, multipleMilvusResults); err != nil {
if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil {
return err
}
if fields, err := t.reorganizeRequeryResults(ctx, queryResult.GetFieldsData(), []*schemapb.IDs{t.result.Results.Ids}); err != nil {
@ -696,11 +702,7 @@ func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, mult
t.result.Results.FieldsData = fields[0]
}
} else {
params := rerank.NewSearchParams(
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics,
)
if t.result, err = t.functionScore.Process(ctx, params, multipleMilvusResults); err != nil {
if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil {
return err
}
}
@ -823,11 +825,15 @@ func (t *searchTask) searchPostProcess(ctx context.Context, span trace.Span, toR
}
if t.functionScore != nil && len(result.Results.FieldsData) != 0 {
params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(),
t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, []string{metricType})
// rank only returns id and score
if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil {
return err
{
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
defer sp.End()
params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(),
t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, []string{metricType})
// rank only returns id and score
if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil {
return err
}
}
if !t.needRequery {
fields, err := t.reorganizeRequeryResults(ctx, result.Results.FieldsData, []*schemapb.IDs{t.result.Results.Ids})

View File

@ -109,16 +109,17 @@ const (
siliconflowAKEnvStr string = "MILVUSAI_SILICONFLOW_API_KEY"
)
// TEI
// TEI and vllm
const (
ingestionPromptParamKey string = "ingestion_prompt"
searchPromptParamKey string = "search_prompt"
maxClientBatchSizeParamKey string = "max_client_batch_size"
truncationDirectionParamKey string = "truncation_direction"
endpointParamKey string = "endpoint"
EndpointParamKey string = "endpoint"
enableTeiEnvStr string = "MILVUSAI_ENABLE_TEI"
EnableTeiEnvStr string = "MILVUSAI_ENABLE_TEI"
EnableVllmEnvStr string = "MILVUSAI_ENABLE_VLLM"
)
func parseAKAndURL(credentials *credentials.Credentials, params []*commonpb.KeyValuePair, confParams map[string]string, apiKeyEnv string) (string, string, error) {

View File

@ -140,7 +140,7 @@ func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
}
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
if err != nil {
return nil, err
}

View File

@ -106,7 +106,7 @@ func (c *CohereEmbedding) Embedding(modelName string, texts []string, inputType
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("bearer %s", c.apiKey),
}
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
if err != nil {
return nil, err
}

View File

@ -149,7 +149,7 @@ func (c *openAIBase) embedding(url string, headers map[string]string, modelName
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
defer cancel()
body, err := utils.RetrySend(ctx, data, http.MethodPost, url, headers, 3, 1)
body, err := utils.RetrySend(ctx, data, http.MethodPost, url, headers, 3)
if err != nil {
return nil, err
}

View File

@ -218,8 +218,8 @@ func TestEmbeddingFailed(t *testing.T) {
func TestTimeout(t *testing.T) {
var st int32 = 0
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// (Timeout 1s + Wait 1s) * Retry 3
time.Sleep(6 * time.Second)
// (Timeout 1s + 2s + 4s + Wait 1s * 3)
time.Sleep(11 * time.Second)
atomic.AddInt32(&st, 1)
w.WriteHeader(http.StatusUnauthorized)
}))

View File

@ -121,7 +121,7 @@ func (c *SiliconflowEmbedding) Embedding(modelName string, texts []string, encod
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
}
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
if err != nil {
return nil, err
}

View File

@ -93,7 +93,7 @@ func (c *TEIEmbedding) Embedding(texts []string, truncate bool, truncationDirect
if c.apiKey != "" {
headers["Authorization"] = fmt.Sprintf("Bearer %s", c.apiKey)
}
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
if err != nil {
return nil, err
}

View File

@ -21,6 +21,7 @@ import (
"context"
"fmt"
"io"
"math/rand"
"net/http"
"time"
)
@ -45,7 +46,7 @@ func send(req *http.Request) ([]byte, error) {
return body, nil
}
func RetrySend(ctx context.Context, data []byte, httpMethod string, url string, headers map[string]string, maxRetries int, retryDelay int) ([]byte, error) {
func RetrySend(ctx context.Context, data []byte, httpMethod string, url string, headers map[string]string, maxRetries int) ([]byte, error) {
var err error
var body []byte
for i := 0; i < maxRetries; i++ {
@ -60,7 +61,9 @@ func RetrySend(ctx context.Context, data []byte, httpMethod string, url string,
if err == nil {
return body, nil
}
time.Sleep(time.Duration(retryDelay) * time.Second)
backoffDelay := 1 << uint(i) * time.Second
jitter := time.Duration(rand.Int63n(int64(backoffDelay / 4)))
time.Sleep(backoffDelay + jitter)
}
return nil, err
}

View File

@ -148,7 +148,7 @@ func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int6
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", token),
}
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
if err != nil {
return nil, err
}

View File

@ -138,7 +138,7 @@ func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int,
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
}
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
if err != nil {
return nil, err
}

View File

@ -55,34 +55,17 @@ type DecayFunction[T PKType, R int32 | int64 | float32 | float64] struct {
}
func newDecayFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
pkType := schemapb.DataType_None
for _, field := range collSchema.Fields {
if field.IsPrimaryKey {
pkType = field.DataType
}
}
if pkType == schemapb.DataType_None {
return nil, fmt.Errorf("Collection %s can not found pk field", collSchema.Name)
}
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false, pkType)
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false)
if err != nil {
return nil, err
}
if len(base.GetInputFieldNames()) != 1 {
return nil, fmt.Errorf("Decay function only supoorts single input, but gets [%s] input", base.GetInputFieldNames())
return nil, fmt.Errorf("Decay function only supports single input, but gets [%s] input", base.GetInputFieldNames())
}
var inputType schemapb.DataType
for _, field := range collSchema.Fields {
if field.Name == base.GetInputFieldNames()[0] {
inputType = field.DataType
}
}
if pkType == schemapb.DataType_Int64 {
inputType := base.GetInputFieldTypes()[0]
if base.pkType == schemapb.DataType_Int64 {
switch inputType {
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
return newFunction[int64, int32](base, funcSchema)
@ -160,7 +143,7 @@ func newFunction[T PKType, R int32 | int64 | float32 | float64](base *RerankBase
}
if decayFunc.decay <= 0 || decayFunc.decay >= 1 {
return nil, fmt.Errorf("Decay function param: decay must 0 < decay < 1, but got %f", decayFunc.offset)
return nil, fmt.Errorf("Decay function param: decay must 0 < decay < 1, but got %f", decayFunc.decay)
}
switch decayFunc.functionName {

View File

@ -90,7 +90,7 @@ func (s *DecayFunctionSuite) TestNewDecayErrors() {
{
functionSchema.InputFieldNames = []string{"ts", "pk"}
_, err := newDecayFunction(schema, functionSchema)
s.ErrorContains(err, "Decay function only supoorts single input, but gets")
s.ErrorContains(err, "Decay function only supports single input, but gets")
}
{

View File

@ -32,6 +32,7 @@ import (
const (
decayFunctionName string = "decay"
modelFunctionName string = "model"
)
type SearchParams struct {
@ -92,8 +93,10 @@ func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.
switch rerankerName {
case decayFunctionName:
rerankFunc, newRerankErr = newDecayFunction(collSchema, funcSchema)
case modelFunctionName:
rerankFunc, newRerankErr = newModelFunction(collSchema, funcSchema)
default:
return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s]", rerankerName, decayFunctionName)
return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s,%s]", rerankerName, decayFunctionName, modelFunctionName)
}
if newRerankErr != nil {

View File

@ -0,0 +1,378 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package rerank
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"sort"
"strconv"
"strings"
"time"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/internal/util/function/models/utils"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
// Parameter keys and provider names understood by the model rerank function.
const (
	// providerParamName is the function param key that selects the rerank provider.
	providerParamName string = "provider"
	// vllmProviderName is the provider value that routes to newVllmProvider.
	vllmProviderName string = "vllm"
	// teiProviderName is the provider value that routes to newTeiProvider.
	teiProviderName string = "tei"
	// queryKeyName is the function param key holding the JSON array of query strings.
	queryKeyName string = "queries"
	// maxBatchKeyName is the function param key limiting documents per rerank request.
	maxBatchKeyName string = "max_batch"
)
// modelProvider is the behaviour ModelFunction needs from a rerank backend.
type modelProvider interface {
	// rerank takes a context, a query string and a list of documents and
	// returns relevance scores; baseModel's implementation returns exactly one
	// score per document or an error.
	rerank(context.Context, string, []string) ([]float32, error)
	// getURL returns the URL the provider sends rerank requests to.
	getURL() string
}
// baseModel holds the request settings shared by the HTTP rerank providers
// and implements modelProvider on their behalf.
type baseModel struct {
	// url is the full request URL (endpoint plus provider-specific path).
	url string
	// maxBatch is the maximum number of documents placed in a single request.
	maxBatch int
	// queryKey is the JSON field name under which the query is sent.
	queryKey string
	// docKey is the JSON field name under which the document list is sent.
	docKey string
	// parseScores decodes a provider response body into a list of scores.
	parseScores func([]byte) ([]float32, error)
}
// getURL returns the URL that rerank requests are posted to.
func (base *baseModel) getURL() string {
	return base.url
}
// rerank scores every document in docs against query.
//
// The documents are split into requests of at most base.maxBatch documents,
// each request is sent in turn with a 30 second timeout, and the scores of all
// batches are concatenated in request order. An error is returned when a
// request body cannot be built, when any call fails (the underlying error is
// wrapped so callers can inspect it with errors.Is / errors.As), or when the
// total number of scores differs from the number of documents.
func (base *baseModel) rerank(ctx context.Context, query string, docs []string) ([]float32, error) {
	requestBodies, err := genRerankRequestBody(query, docs, base.maxBatch, base.queryKey, base.docKey)
	if err != nil {
		return nil, err
	}
	// One score is expected per document, so the final size is known up front.
	scores := make([]float32, 0, len(docs))
	for _, requestBody := range requestBodies {
		batchScores, err := base.callService(ctx, requestBody, 30)
		if err != nil {
			return nil, fmt.Errorf("Call rerank model failed: %w", err)
		}
		scores = append(scores, batchScores...)
	}

	if len(scores) != len(docs) {
		return nil, fmt.Errorf("Call Rerank service failed, %d docs but got %d scores", len(docs), len(scores))
	}
	return scores, nil
}
// callService posts one JSON request body to base.url and decodes the reply
// with base.parseScores. The call, including the retries performed by
// utils.RetrySend, is bounded by timeoutSec seconds on top of any deadline
// already carried by ctx.
func (base *baseModel) callService(ctx context.Context, requestBody []byte, timeoutSec int64) ([]float32, error) {
	limit := time.Duration(timeoutSec) * time.Second
	ctx, cancel := context.WithTimeout(ctx, limit)
	defer cancel()

	respBody, err := utils.RetrySend(
		ctx, requestBody, http.MethodPost, base.url,
		map[string]string{"Content-Type": "application/json"},
		3,
	)
	if err != nil {
		return nil, err
	}
	return base.parseScores(respBody)
}
// JSON shapes used when talking to a vllm rerank endpoint.
type (
	// vllmRerankRequest describes a request carrying one query and its
	// candidate documents.
	vllmRerankRequest struct {
		Query     string   `json:"query"`
		Documents []string `json:"documents"`
	}

	// vllmRerankResponse is the response body decoded by the vllm provider.
	vllmRerankResponse struct {
		ID      string       `json:"id"`
		Model   string       `json:"model"`
		Usage   vllmUsage    `json:"usage"`
		Results []vllmResult `json:"results"`
	}

	// vllmUsage is the "usage" object of a vllm rerank response.
	vllmUsage struct {
		TotalTokens int `json:"total_tokens"`
	}

	// vllmResult is one scored document; Index is used to restore the
	// order of the scores before they are returned.
	vllmResult struct {
		Index          int          `json:"index"`
		Document       vllmDocument `json:"document"`
		RelevanceScore float32      `json:"relevance_score"`
	}

	// vllmDocument is the "document" object attached to a result.
	vllmDocument struct {
		Text string `json:"text"`
	}
)
// vllmProvider is the modelProvider returned by newVllmProvider; all of its
// behaviour comes from the embedded baseModel.
type vllmProvider struct {
	baseModel
}
// newVllmProvider builds a rerank provider that posts to "/v2/rerank" on the
// endpoint given in params.
//
// It fails when the provider is disabled (see isEnable), when params do not
// pass parseParams, or when the validated endpoint cannot be parsed again.
// The path of the endpoint is replaced by "/v2/rerank".
func newVllmProvider(params []*commonpb.KeyValuePair, conf map[string]string) (modelProvider, error) {
	if !isEnable(conf, function.EnableVllmEnvStr) {
		return nil, fmt.Errorf("Vllm rerank is disabled")
	}
	endpoint, maxBatch, err := parseParams(params)
	if err != nil {
		return nil, err
	}
	// Do not discard the parse error: a nil *url.URL would panic on the Path
	// assignment below.
	base, err := url.Parse(endpoint)
	if err != nil {
		return nil, fmt.Errorf("Rerank endpoint: [%s] is not a valid http/https link: %w", endpoint, err)
	}
	base.Path = "/v2/rerank"

	model := baseModel{
		url:      base.String(),
		maxBatch: maxBatch,
		queryKey: "query",
		docKey:   "documents",
		parseScores: func(body []byte) ([]float32, error) {
			var rerankResp vllmRerankResponse
			if err := json.Unmarshal(body, &rerankResp); err != nil {
				return nil, fmt.Errorf("Rerank error, parsing vllm response failed: %w", err)
			}
			// Results may arrive in any order; sort by index so that the
			// scores line up with the order of the submitted documents.
			sort.Slice(rerankResp.Results, func(i, j int) bool {
				return rerankResp.Results[i].Index < rerankResp.Results[j].Index
			})
			scores := make([]float32, 0, len(rerankResp.Results))
			for _, result := range rerankResp.Results {
				scores = append(scores, result.RelevanceScore)
			}
			return scores, nil
		},
	}
	return &vllmProvider{baseModel: model}, nil
}
// teiProvider is the modelProvider returned by newTeiProvider; all of its
// behaviour comes from the embedded baseModel.
type teiProvider struct {
	baseModel
}
// TEIResponse is one element of the JSON array that the TEI provider decodes
// from a rerank response: the index of a document and its score.
type TEIResponse struct {
	Index int `json:"index"`
	Score float32 `json:"score"`
}
// newTeiProvider builds a rerank provider that posts to "/rerank" on the
// endpoint given in params.
//
// It fails when the provider is disabled (see isEnable), when params do not
// pass parseParams, or when the validated endpoint cannot be parsed again.
// The path of the endpoint is replaced by "/rerank".
func newTeiProvider(params []*commonpb.KeyValuePair, conf map[string]string) (modelProvider, error) {
	if !isEnable(conf, function.EnableTeiEnvStr) {
		return nil, fmt.Errorf("TEI rerank is disabled")
	}
	endpoint, maxBatch, err := parseParams(params)
	if err != nil {
		return nil, err
	}
	// Do not discard the parse error: a nil *url.URL would panic on the Path
	// assignment below.
	base, err := url.Parse(endpoint)
	if err != nil {
		return nil, fmt.Errorf("Rerank endpoint: [%s] is not a valid http/https link: %w", endpoint, err)
	}
	base.Path = "/rerank"

	model := baseModel{
		url:      base.String(),
		maxBatch: maxBatch,
		queryKey: "query",
		docKey:   "texts",
		parseScores: func(body []byte) ([]float32, error) {
			var results []TEIResponse
			if err := json.Unmarshal(body, &results); err != nil {
				return nil, fmt.Errorf("Rerank error, parsing TEI response failed: %w", err)
			}
			// Results may arrive in any order; sort by index so that the
			// scores line up with the order of the submitted documents.
			sort.Slice(results, func(i, j int) bool {
				return results[i].Index < results[j].Index
			})
			scores := make([]float32, len(results))
			for i, result := range results {
				scores[i] = result.Score
			}
			return scores, nil
		},
	}
	return &teiProvider{baseModel: model}, nil
}
// isEnable reports whether a rerank provider is switched on.
//
// An "enable" entry in conf takes precedence: the provider is enabled only
// when that value is "true" (case-insensitive). Without such an entry the
// environment variable envKey decides, and the provider is enabled unless the
// variable is set to "false" (case-insensitive).
func isEnable(conf map[string]string, envKey string) bool {
	if configured, found := conf["enable"]; found {
		return strings.ToLower(configured) == "true"
	}
	return strings.ToLower(os.Getenv(envKey)) != "false"
}
// parseParams extracts the rerank endpoint and the maximum batch size from
// the function params. Keys are matched case-insensitively.
//
// The endpoint is mandatory and must be an http or https URL with a host.
// max_batch is optional, defaults to 32 and must be a positive integer.
func parseParams(params []*commonpb.KeyValuePair) (string, int, error) {
	var endpoint string
	maxBatch := 32

	for _, kv := range params {
		switch strings.ToLower(kv.Key) {
		case function.EndpointParamKey:
			parsed, err := url.Parse(kv.Value)
			if err != nil {
				return "", 0, err
			}
			httpScheme := parsed.Scheme == "http" || parsed.Scheme == "https"
			if !httpScheme || parsed.Host == "" {
				return "", 0, fmt.Errorf("Rerank endpoint: [%s] is not a valid http/https link", kv.Value)
			}
			endpoint = parsed.String()
		case maxBatchKeyName:
			batch, err := strconv.ParseInt(kv.Value, 10, 64)
			if err != nil {
				return "", 0, fmt.Errorf("Rerank params error, maxBatch: %s is not a number", kv.Value)
			}
			maxBatch = int(batch)
		}
	}

	switch {
	case endpoint == "":
		return "", 0, fmt.Errorf("Rerank function lost params endpoint")
	case maxBatch <= 0:
		return "", 0, fmt.Errorf("Rerank function params max_batch must > 0, but got %d", maxBatch)
	}
	return endpoint, maxBatch, nil
}
// genRerankRequestBody splits documents into consecutive chunks of at most
// maxSize entries and encodes one JSON request body per chunk. Each body is an
// object with the query stored under queryKey and the chunk of documents
// stored under docKey.
//
// maxSize must be positive; otherwise an error is returned, because a zero
// step would never advance through documents and a negative one would produce
// an invalid slice range. An empty documents list yields an empty, non-nil
// result.
func genRerankRequestBody(query string, documents []string, maxSize int, queryKey string, docKey string) ([][]byte, error) {
	if maxSize <= 0 {
		return nil, fmt.Errorf("Create model rerank request failed, err: batch size must > 0, but got %d", maxSize)
	}

	batchCount := (len(documents) + maxSize - 1) / maxSize
	requestBodies := make([][]byte, 0, batchCount)
	for start := 0; start < len(documents); start += maxSize {
		end := start + maxSize
		if end > len(documents) {
			end = len(documents)
		}
		requestBody := map[string]any{
			queryKey: query,
			docKey:   documents[start:end],
		}
		jsonData, err := json.Marshal(requestBody)
		if err != nil {
			return nil, fmt.Errorf("Create model rerank request failed, err: %w", err)
		}
		requestBodies = append(requestBodies, jsonData)
	}
	return requestBodies, nil
}
// newProvider creates the rerank provider named by the "provider" param.
//
// The first param whose key matches (case-insensitively) decides: its
// lower-cased value selects the provider-specific configuration and
// constructor. An unknown provider name or a missing provider param is an
// error.
func newProvider(params []*commonpb.KeyValuePair) (modelProvider, error) {
	for _, kv := range params {
		if strings.ToLower(kv.Key) != providerParamName {
			continue
		}
		name := strings.ToLower(kv.Value)
		providerConf := paramtable.Get().FunctionCfg.GetRerankModelProviders(name)
		if name == vllmProviderName {
			return newVllmProvider(params, providerConf)
		}
		if name == teiProviderName {
			return newTeiProvider(params, providerConf)
		}
		return nil, fmt.Errorf("Unknow rerank provider:%s", kv.Value)
	}
	return nil, fmt.Errorf("Lost rerank params:%s ", providerParamName)
}
// ModelFunction is a reranker that scores search hits by sending their text
// to an external rerank model. T is the Go type of the collection's primary key.
type ModelFunction[T PKType] struct {
	RerankBase

	// provider is the remote rerank service used to score documents.
	provider modelProvider
	// queries holds one query string per search request (nq); Process rejects
	// a mismatch between its length and nq.
	queries []string
}
// newModelFunction builds the model-based reranker described by funcSchema.
//
// The function must have exactly one input field of type VarChar, a valid
// "provider" configuration (see newProvider) and a non-empty "queries" param
// holding a JSON array of strings. The concrete ModelFunction instantiation is
// chosen from the collection's primary key type: int64 when it is Int64,
// string otherwise.
func newModelFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
	// Register the base under the model function's own name rather than the
	// decay function's name.
	base, err := newRerankBase(collSchema, funcSchema, modelFunctionName, false)
	if err != nil {
		return nil, err
	}

	if len(base.GetInputFieldNames()) != 1 {
		return nil, fmt.Errorf("Rerank model only supports single input, but gets [%s] input", base.GetInputFieldNames())
	}

	if base.GetInputFieldTypes()[0] != schemapb.DataType_VarChar {
		return nil, fmt.Errorf("Rerank model only support varchar, but got [%s]", base.GetInputFieldTypes()[0].String())
	}

	provider, err := newProvider(funcSchema.Params)
	if err != nil {
		return nil, err
	}

	queries := []string{}
	for _, param := range funcSchema.Params {
		if param.Key == queryKeyName {
			if err := json.Unmarshal([]byte(param.Value), &queries); err != nil {
				return nil, fmt.Errorf("Parse rerank params [queries] failed, err: %v", err)
			}
		}
	}
	if len(queries) == 0 {
		return nil, fmt.Errorf("Rerank function lost params queries")
	}

	if base.pkType == schemapb.DataType_Int64 {
		return &ModelFunction[int64]{RerankBase: *base, provider: provider, queries: queries}, nil
	}
	return &ModelFunction[string]{RerankBase: *base, provider: provider, queries: queries}, nil
}
// processOneSearchData reranks the hits belonging to a single query.
//
// Hits from all columns are de-duplicated by primary key (the first text seen
// for a key is kept), the unique texts are scored against query by the
// provider, and the resulting id-to-score mapping is turned into an IDScores
// via newIDScores.
func (model *ModelFunction[T]) processOneSearchData(ctx context.Context, searchParams *SearchParams, query string, cols []*columns) (*IDScores[T], error) {
	textByID := make(map[T]string)
	for _, col := range cols {
		docs := col.data[0].([]string)
		pks := col.ids.([]T)
		for pos, pk := range pks {
			if _, seen := textByID[pk]; seen {
				continue
			}
			textByID[pk] = docs[pos]
		}
	}

	// Flatten the map into two parallel slices so that scores returned by the
	// provider can be matched back to their primary keys by position.
	pks := make([]T, 0, len(textByID))
	docs := make([]string, 0, len(textByID))
	for pk, doc := range textByID {
		pks = append(pks, pk)
		docs = append(docs, doc)
	}

	scores, err := model.provider.rerank(ctx, query, docs)
	if err != nil {
		return nil, err
	}

	scoreByID := map[T]float32{}
	for pos, pk := range pks {
		scoreByID[pk] = scores[pos]
	}
	return newIDScores(scoreByID, searchParams), nil
}
// Process reranks every entry of inputs.data with the query at the same
// position in model.queries and appends the resulting ids and scores to the
// output. It returns an error when the number of configured queries differs
// from searchParams.nq or when reranking any entry fails.
func (model *ModelFunction[T]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) {
	if queryCount := len(model.queries); queryCount != int(searchParams.nq) {
		return nil, fmt.Errorf("nq must equal to queries size, but got nq [%d], queries size [%d], queries: [%v]", searchParams.nq, queryCount, model.queries)
	}

	result := newRerankOutputs(searchParams)
	for i, cols := range inputs.data {
		ranked, err := model.processOneSearchData(ctx, searchParams, model.queries[i], cols)
		if err != nil {
			return nil, err
		}
		appendResult(result, ranked.ids, ranked.scores)
	}
	return result, nil
}

View File

@ -0,0 +1,509 @@
/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package rerank
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/stretchr/testify/suite"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
// TestRerankModel runs every test method of RerankModelSuite.
func TestRerankModel(t *testing.T) {
	suite.Run(t, new(RerankModelSuite))
}
// RerankModelSuite groups the tests for the model rerank function and its
// providers.
type RerankModelSuite struct {
	suite.Suite
}
// SetupTest initialises the param table and makes the rerank provider
// configuration return an empty map.
func (s *RerankModelSuite) SetupTest() {
	paramtable.Init()
	emptyProviderConf := func() map[string]string {
		return map[string]string{}
	}
	paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = emptyProviderConf
}
// TestNewProvider covers provider selection, endpoint and max_batch
// validation, URL construction and the enable/disable switches of
// newProvider. The cases run in order and the last two temporarily change the
// global provider configuration and the process environment.
func (s *RerankModelSuite) TestNewProvider() {
	// Unknown provider name.
	{
		params := []*commonpb.KeyValuePair{
			{Key: providerParamName, Value: "unknown"},
		}
		_, err := newProvider(params)
		s.ErrorContains(err, "Unknow rerank provider")
	}
	// No provider param at all.
	{
		_, err := newProvider([]*commonpb.KeyValuePair{})
		s.ErrorContains(err, "Lost rerank params")
	}
	// Endpoint without an http/https scheme.
	{
		params := []*commonpb.KeyValuePair{
			{Key: providerParamName, Value: "vllm"},
			{Key: function.EndpointParamKey, Value: "illegal"},
		}
		_, err := newProvider(params)
		s.ErrorContains(err, "is not a valid http/https link")
	}
	// Endpoint with a scheme but no host.
	{
		params := []*commonpb.KeyValuePair{
			{Key: providerParamName, Value: "vllm"},
			{Key: function.EndpointParamKey, Value: "http://"},
		}
		_, err := newProvider(params)
		s.ErrorContains(err, "is not a valid http/https link")
	}
	// Missing endpoint.
	{
		params := []*commonpb.KeyValuePair{
			{Key: providerParamName, Value: "vllm"},
		}
		_, err := newProvider(params)
		s.ErrorContains(err, "Rerank function lost params endpoint")
	}
	// Non-positive max_batch.
	{
		params := []*commonpb.KeyValuePair{
			{Key: providerParamName, Value: "vllm"},
			{Key: function.EndpointParamKey, Value: "http://localhost:80"},
			{Key: maxBatchKeyName, Value: "-1"},
		}
		_, err := newProvider(params)
		s.ErrorContains(err, "Rerank function params max_batch must > 0")
	}
	// Non-numeric max_batch.
	{
		params := []*commonpb.KeyValuePair{
			{Key: providerParamName, Value: "vllm"},
			{Key: function.EndpointParamKey, Value: "http://localhost:80"},
			{Key: maxBatchKeyName, Value: "NotNum"},
		}
		_, err := newProvider(params)
		s.ErrorContains(err, "Rerank params error, maxBatch")
	}
	// Valid vllm provider: the request path is /v2/rerank.
	{
		params := []*commonpb.KeyValuePair{
			{Key: providerParamName, Value: "vllm"},
			{Key: function.EndpointParamKey, Value: "http://localhost:80"},
		}
		provder, err := newProvider(params)
		s.NoError(err)
		s.Equal(provder.getURL(), "http://localhost:80/v2/rerank")
	}
	// Valid TEI provider: the request path is /rerank.
	{
		params := []*commonpb.KeyValuePair{
			{Key: providerParamName, Value: "tei"},
			{Key: function.EndpointParamKey, Value: "http://localhost:80"},
		}
		provder, err := newProvider(params)
		s.NoError(err)
		s.Equal(provder.getURL(), "http://localhost:80/rerank")
	}
	// TEI disabled, first through the provider configuration and then, with
	// the configuration emptied again, through the environment variable.
	{
		params := []*commonpb.KeyValuePair{
			{Key: providerParamName, Value: "tei"},
			{Key: function.EndpointParamKey, Value: "http://localhost:80"},
		}
		paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = func() map[string]string {
			key := "tei.enable"
			return map[string]string{
				key: "false",
			}
		}
		_, err := newProvider(params)
		s.ErrorContains(err, "TEI rerank is disabled")

		paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = func() map[string]string {
			return map[string]string{}
		}
		os.Setenv(function.EnableTeiEnvStr, "false")
		_, err = newProvider(params)
		s.ErrorContains(err, "TEI rerank is disabled")
		// Restore the environment for the following cases.
		os.Unsetenv(function.EnableTeiEnvStr)
	}
	// vllm disabled, first through the provider configuration and then, with
	// the configuration emptied again, through the environment variable.
	{
		params := []*commonpb.KeyValuePair{
			{Key: providerParamName, Value: "vllm"},
			{Key: function.EndpointParamKey, Value: "http://localhost:80"},
		}
		paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = func() map[string]string {
			key := "vllm.enable"
			return map[string]string{
				key: "false",
			}
		}
		_, err := newProvider(params)
		s.ErrorContains(err, "Vllm rerank is disabled")

		paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = func() map[string]string {
			return map[string]string{}
		}
		os.Setenv(function.EnableVllmEnvStr, "false")
		_, err = newProvider(params)
		s.ErrorContains(err, "Vllm rerank is disabled")
		// Restore the environment for later tests.
		os.Unsetenv(function.EnableVllmEnvStr)
	}
}
// TestCallVllm exercises the vllm provider against stub HTTP servers: a
// response with too few scores, an HTTP error, a non-JSON body, and a valid
// response whose results arrive out of order.
func (s *RerankModelSuite) TestCallVllm() {
	// callVllm starts a stub server that answers every request with the given
	// status and payload, then reranks three documents through a vllm provider
	// pointed at it.
	callVllm := func(status int, payload string) ([]float32, error) {
		ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(status)
			w.Write([]byte(payload))
		}))
		defer ts.Close()

		provider, err := newProvider([]*commonpb.KeyValuePair{
			{Key: providerParamName, Value: "vllm"},
			{Key: function.EndpointParamKey, Value: ts.URL},
		})
		s.NoError(err)
		return provider.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"})
	}

	// Fewer scores than documents.
	_, err := callVllm(http.StatusOK, `{"results": [{"index": 0, "relevance_score": 0.1}, {"index": 1, "relevance_score": 0.2}]}`)
	s.ErrorContains(err, "Call Rerank service failed, 3 docs but got 2 scores")

	// The service answers with an HTTP error status.
	_, err = callVllm(http.StatusBadRequest, `{'object': 'error', 'message': 'The model vllm-test does not exist.', 'type': 'NotFoundError', 'param': None, 'code': 404}`)
	s.ErrorContains(err, "Call rerank model failed")

	// The body is not JSON.
	_, err = callVllm(http.StatusOK, `Not json data`)
	s.ErrorContains(err, "Rerank error, parsing vllm response failed")

	// Results out of order are returned sorted by index.
	scores, err := callVllm(http.StatusOK, `{"results": [{"index": 0, "relevance_score": 0.0}, {"index": 2, "relevance_score": 0.2}, {"index": 1, "relevance_score": 0.1}]}`)
	s.NoError(err)
	s.Equal([]float32{0.0, 0.1, 0.2}, scores)
}
// TestCallTEI exercises the TEI provider against stub HTTP servers: a valid
// response and a non-JSON body.
func (s *RerankModelSuite) TestCallTEI() {
	// callTEI starts a stub server that answers every request with HTTP 200
	// and the given payload, then reranks three documents through a TEI
	// provider pointed at it.
	callTEI := func(payload string) ([]float32, error) {
		ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusOK)
			w.Write([]byte(payload))
		}))
		defer ts.Close()

		provider, err := newProvider([]*commonpb.KeyValuePair{
			{Key: providerParamName, Value: "tei"},
			{Key: function.EndpointParamKey, Value: ts.URL},
		})
		s.NoError(err)
		return provider.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"})
	}

	// A well-formed response yields the scores ordered by index.
	scores, err := callTEI(`[{"index":0,"score":0.0},{"index":1,"score":0.2}, {"index":2,"score":0.1}]`)
	s.NoError(err)
	s.Equal([]float32{0.0, 0.2, 0.1}, scores)

	// The body is not JSON.
	_, err = callTEI(`not json`)
	s.ErrorContains(err, "Rerank error, parsing TEI response failed")
}
// TestNewModelFunction checks the validation performed by newModelFunction:
// input field count and type, the "queries" param, both primary key types and
// the provider param. The cases share schema and functionSchema and mutate
// them in order, so each case depends on the state left by the previous one.
func (s *RerankModelSuite) TestNewModelFunction() {
	schema := &schemapb.CollectionSchema{
		Name: "test",
		Fields: []*schemapb.FieldSchema{
			{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
			{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
			{
				FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
				TypeParams: []*commonpb.KeyValuePair{
					{Key: "dim", Value: "4"},
				},
			},
			// Field IDs must be unique within a schema; "vector" already uses 102.
			{FieldID: 103, Name: "ts", DataType: schemapb.DataType_Int64},
		},
	}
	functionSchema := &schemapb.FunctionSchema{
		Name:            "test",
		Type:            schemapb.FunctionType_Rerank,
		InputFieldNames: []string{"text"},
		Params: []*commonpb.KeyValuePair{
			{Key: providerParamName, Value: "tei"},
			{Key: function.EndpointParamKey, Value: "http://localhost:80"},
			{Key: queryKeyName, Value: `["q1"]`},
		},
	}

	// More than one input field.
	{
		functionSchema.InputFieldNames = []string{"text", "ts"}
		_, err := newModelFunction(schema, functionSchema)
		s.ErrorContains(err, "Rerank model only supports single input")
	}
	// Input field that is not VarChar; the valid input is restored afterwards.
	{
		functionSchema.InputFieldNames = []string{"ts"}
		_, err := newModelFunction(schema, functionSchema)
		s.ErrorContains(err, "Rerank model only support varchar")
		functionSchema.InputFieldNames = []string{"text"}
	}
	// "queries" is not valid JSON.
	{
		functionSchema.Params[2] = &commonpb.KeyValuePair{Key: queryKeyName, Value: `NotJson`}
		_, err := newModelFunction(schema, functionSchema)
		s.ErrorContains(err, "Parse rerank params [queries] failed")
	}
	// "queries" is an empty list.
	{
		functionSchema.Params[2] = &commonpb.KeyValuePair{Key: queryKeyName, Value: `[]`}
		_, err := newModelFunction(schema, functionSchema)
		s.ErrorContains(err, "Rerank function lost params queries")
	}
	// Valid configuration with an Int64 primary key.
	{
		functionSchema.Params[2] = &commonpb.KeyValuePair{Key: queryKeyName, Value: `["test"]`}
		_, err := newModelFunction(schema, functionSchema)
		s.NoError(err)
	}
	// Valid configuration with a VarChar primary key.
	{
		schema.Fields[0] = &schemapb.FieldSchema{FieldID: 100, Name: "pk", DataType: schemapb.DataType_VarChar, IsPrimaryKey: true}
		_, err := newModelFunction(schema, functionSchema)
		s.NoError(err)
	}
	// Unknown provider name.
	{
		functionSchema.Params[0] = &commonpb.KeyValuePair{Key: providerParamName, Value: `notExist`}
		_, err := newModelFunction(schema, functionSchema)
		s.ErrorContains(err, "Unknow rerank provider")
	}
	// Provider param missing entirely.
	{
		functionSchema.Params[0] = &commonpb.KeyValuePair{Key: "NotExist", Value: `notExist`}
		_, err := newModelFunction(schema, functionSchema)
		s.ErrorContains(err, "Lost rerank params")
	}
}
// TestRerankProcess exercises model-based rerank functions end to end:
// newModelFunction builds the function from a schema, newRerankInputs wraps
// search results, and Process is run against local httptest servers that stand
// in for the remote rerank endpoint ("tei" first, then "vllm").
//
// The sub-cases share and mutate a single functionSchema (its third param,
// the queries list), so their order matters.
func (s *RerankModelSuite) TestRerankProcess() {
// Collection with an Int64 primary key, a VarChar "text" field (the rerank
// input) and a 4-dim float vector field.
schema := &schemapb.CollectionSchema{
Name: "test",
Fields: []*schemapb.FieldSchema{
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
{
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
TypeParams: []*commonpb.KeyValuePair{
{Key: "dim", Value: "4"},
},
},
},
}
// Part 1: "tei" provider backed by a server that always answers with the
// same fixed two-result payload, regardless of the request.
{
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"results": [{"index": 0, "relevance_score": 0.1}, {"index": 1, "relevance_score": 0.2}]}`))
}))
// The deferred Close runs when TestRerankProcess returns, not at the end
// of this bare block; this ts is a different variable from the one
// declared later in the function.
defer ts.Close()
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Rerank,
InputFieldNames: []string{"text"},
Params: []*commonpb.KeyValuePair{
{Key: providerParamName, Value: "tei"},
{Key: function.EndpointParamKey, Value: ts.URL},
{Key: queryKeyName, Value: `["q1"]`},
},
}
// No search results at all: Process must succeed and report TopK 3 with an
// empty Topks slice.
{
nq := int64(1)
f, err := newModelFunction(schema, functionSchema)
s.NoError(err)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs())
// SearchParams is filled positionally; presumably the fields follow the
// NewSearchParams argument order (nq, limit, offset, roundDecimal,
// groupByFieldId, groupSize, strictGroupSize, searchMetrics) — confirm
// against the struct definition.
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs)
s.NoError(err)
s.Equal(int64(3), ret.searchResultData.TopK)
s.Equal([]int64{}, ret.searchResultData.Topks)
}
// The search result carries a field ("noExist", id 1000) that is not the
// function's input field, so building the rerank inputs must fail.
{
nq := int64(1)
f, err := newModelFunction(schema, functionSchema)
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "noExist", 1000)
s.NoError(err)
_, err = newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
s.ErrorContains(err, "Search reaults mismatch rerank inputs")
}
}
// Part 2: mock server that answers with one result per entry of the request's
// "documents" array, scoring entry i as i/10. Decode errors are ignored and
// the type assertion is unchecked, so a request without a "documents" array
// would panic inside the handler.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
req := map[string]any{}
body, _ := io.ReadAll(r.Body)
defer r.Body.Close()
json.Unmarshal(body, &req)
ret := map[string][]map[string]any{}
ret["results"] = []map[string]any{}
for i := range req["documents"].([]any) {
d := map[string]any{}
d["index"] = i
d["relevance_score"] = float32(i) / 10
ret["results"] = append(ret["results"], d)
}
jsonData, _ := json.Marshal(ret)
w.Write(jsonData)
}))
defer ts.Close()
// Single SearchResultData input, "vllm" provider. The schema starts with two
// queries; later sub-cases overwrite Params[2].Value in place.
functionSchema := &schemapb.FunctionSchema{
Name: "test",
Type: schemapb.FunctionType_Rerank,
InputFieldNames: []string{"text"},
Params: []*commonpb.KeyValuePair{
{Key: providerParamName, Value: "vllm"},
{Key: function.EndpointParamKey, Value: ts.URL},
{Key: queryKeyName, Value: `["q1", "q2"]`},
},
}
{
nq := int64(1)
// Two configured queries but nq == 1: Process must reject the mismatch.
{
f, err := newModelFunction(schema, functionSchema)
s.NoError(err)
data := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
_, err = f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs)
s.ErrorContains(err, "nq must equal to queries size, but got nq [1], queries size [2]")
}
// One query for nq == 1: expect a single group of 3 results.
{
functionSchema.Params[2].Value = `["q1"]`
f, err := newModelFunction(schema, functionSchema)
s.NoError(err)
data := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 0, -1, -1, 1, false, []string{"COSINE"}}, inputs)
s.NoError(err)
s.Equal([]int64{3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
}
// Three queries for nq == 3 (this nq shadows the outer one): expect 3
// results per query.
{
nq := int64(3)
functionSchema.Params[2].Value = `["q1", "q2", "q3"]`
f, err := newModelFunction(schema, functionSchema)
s.NoError(err)
data := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE", "COSINE"}}, inputs)
s.NoError(err)
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
}
}
// Multiple SearchResultData inputs.
// One of the inputs is empty.
{
nq := int64(1)
// Reset the queries to a single entry; the previous sub-case left three.
functionSchema.Params[2].Value = `["q1"]`
f, err := newModelFunction(schema, functionSchema)
s.NoError(err)
data1 := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
// empty
data2 := genSearchResultData(nq, 0, schemapb.DataType_VarChar, "text", 101)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs())
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs)
s.NoError(err)
s.Equal([]int64{3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
}
// nq = 1 (relies on the single-query value set by the previous sub-case)
{
nq := int64(1)
f, err := newModelFunction(schema, functionSchema)
s.NoError(err)
// ts/id data: 0 - 9
data1 := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
// ts/id data: 0 - 3
data2 := genSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs())
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, inputs)
s.NoError(err)
s.Equal([]int64{3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
}
// nq = 3
{
nq := int64(3)
functionSchema.Params[2].Value = `["q1", "q2", "q3"]`
f, err := newModelFunction(schema, functionSchema)
s.NoError(err)
data1 := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
data2 := genSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101)
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs())
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, 1, -1, 1, false, []string{"COSINE", "COSINE", "COSINE"}}, inputs)
s.NoError(err)
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
s.Equal(int64(3), ret.searchResultData.TopK)
}
}

View File

@ -46,12 +46,18 @@ type RerankBase struct {
pkType schemapb.DataType
inputFieldIDs []int64
inputFieldNames []string
inputFieldTypes []schemapb.DataType
// TODO: The parameter is passed to the reranker, and the reranker decides whether to implement the parameter
searchParams *searchParams
}
func newRerankBase(coll *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema, rerankerName string, isSupportGroup bool, pkType schemapb.DataType) (*RerankBase, error) {
func newRerankBase(coll *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema, rerankerName string, isSupportGroup bool) (*RerankBase, error) {
pkType, err := getPKType(coll)
if err != nil {
return nil, err
}
base := RerankBase{
inputFieldNames: funcSchema.InputFieldNames,
rerankerName: rerankerName,
@ -82,6 +88,7 @@ func newRerankBase(coll *schemapb.CollectionSchema, funcSchema *schemapb.Functio
return nil, fmt.Errorf("Function input field cannot be nullable: field %s", inputField.GetName())
}
base.inputFieldIDs = append(base.inputFieldIDs, inputField.FieldID)
base.inputFieldTypes = append(base.inputFieldTypes, inputField.DataType)
}
return &base, nil
}
@ -90,6 +97,10 @@ func (base *RerankBase) GetInputFieldNames() []string {
return base.inputFieldNames
}
// GetInputFieldTypes returns the stored slice of input-field data types.
// The slice is returned as-is (not copied), so callers share its backing array.
func (base *RerankBase) GetInputFieldTypes() []schemapb.DataType {
return base.inputFieldTypes
}
// GetInputFieldIDs returns the stored slice of input-field IDs.
// The slice is returned as-is (not copied), so callers share its backing array.
func (base *RerankBase) GetInputFieldIDs() []int64 {
return base.inputFieldIDs
}

View File

@ -236,7 +236,7 @@ func getField(inputField *schemapb.FieldData, start int64, size int64) (any, err
return inputField.GetScalars().GetBoolData().Data[start : start+size], nil
}
return []bool{}, nil
case schemapb.DataType_String:
case schemapb.DataType_String, schemapb.DataType_VarChar:
if inputField.GetScalars() != nil && inputField.GetScalars().GetStringData() != nil {
return inputField.GetScalars().GetStringData().Data[start : start+size], nil
}
@ -285,3 +285,17 @@ func maxMerge[T PKType](cols []*columns) map[T]float32 {
}
return srcScores
}
// getPKType returns the data type of the primary-key field declared in
// collSchema.
//
// Every field is inspected; if more than one field is flagged as primary key,
// the last one wins. An error is returned when no primary-key type could be
// determined, which includes a nil schema, a schema without fields, and a
// primary key whose type is DataType_None.
func getPKType(collSchema *schemapb.CollectionSchema) (schemapb.DataType, error) {
	pkType := schemapb.DataType_None
	// The generated getters are nil-safe, so a nil schema or a nil entry in
	// Fields falls through to the "not found" error instead of panicking.
	for _, field := range collSchema.GetFields() {
		if field.GetIsPrimaryKey() {
			pkType = field.GetDataType()
		}
	}
	if pkType == schemapb.DataType_None {
		return pkType, fmt.Errorf("Collection %s can not found pk field", collSchema.GetName())
	}
	return pkType, nil
}

View File

@ -47,7 +47,7 @@ type TeiEmbeddingProvider struct {
}
func createTEIEmbeddingClient(apiKey string, endpoint string) (*tei.TEIEmbedding, error) {
enable := os.Getenv(enableTeiEnvStr)
enable := os.Getenv(EnableTeiEnvStr)
if strings.ToLower(enable) == "false" {
return nil, errors.New("TEI model serving is not enabled")
}
@ -69,7 +69,7 @@ func NewTEIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *
for _, param := range functionSchema.Params {
switch strings.ToLower(param.Key) {
case endpointParamKey:
case EndpointParamKey:
endpoint = param.Value
case ingestionPromptParamKey:
ingestionPrompt = param.Value

View File

@ -70,7 +70,7 @@ func createTEIProvider(url string, schema *schemapb.FieldSchema, providerName st
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: credentialParamKey, Value: "mock"},
{Key: endpointParamKey, Value: url},
{Key: EndpointParamKey, Value: url},
{Key: ingestionPromptParamKey, Value: "doc:"},
{Key: searchPromptParamKey, Value: "query:"},
},
@ -154,8 +154,8 @@ func (s *TEITextEmbeddingProviderSuite) TestCreateTEIEmbeddingClient() {
_, err = createTEIEmbeddingClient("", "http://mymock.com")
s.NoError(err)
os.Setenv(enableTeiEnvStr, "false")
defer os.Unsetenv(enableTeiEnvStr)
os.Setenv(EnableTeiEnvStr, "false")
defer os.Unsetenv(EnableTeiEnvStr)
_, err = createTEIEmbeddingClient("", "http://mymock.com")
s.Error(err)
}
@ -170,7 +170,7 @@ func (s *TEITextEmbeddingProviderSuite) TestNewTEIEmbeddingProvider() {
OutputFieldIds: []int64{102},
Params: []*commonpb.KeyValuePair{
{Key: credentialParamKey, Value: "mock"},
{Key: endpointParamKey, Value: "http://mymock.com"},
{Key: EndpointParamKey, Value: "http://mymock.com"},
},
}
provider, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}, credentials.NewCredentials(map[string]string{"mock.apikey": "mock"}))

View File

@ -22,6 +22,7 @@ import (
// functionConfig holds the parameter groups used to configure function
// providers. Both groups are tagged refreshable.
type functionConfig struct {
// TextEmbeddingProviders is the parameter group for text-embedding providers.
TextEmbeddingProviders ParamGroup `refreshable:"true"`
// RerankModelProviders is the parameter group for rerank-model providers;
// init registers it under the "function.rerank.model.providers." key prefix.
RerankModelProviders ParamGroup `refreshable:"true"`
}
func (p *functionConfig) init(base *BaseTable) {
@ -73,6 +74,23 @@ func (p *functionConfig) init(base *BaseTable) {
},
}
p.TextEmbeddingProviders.Init(base.mgr)
p.RerankModelProviders = ParamGroup{
KeyPrefix: "function.rerank.model.providers.",
Version: "2.6.0",
Export: true,
DocFunc: func(key string) string {
switch key {
case "tei.enable":
return "Whether to enable TEI rerank service"
case "vllm.enable":
return "Whether to enable vllm rerank service"
default:
return ""
}
},
}
p.RerankModelProviders.Init(base.mgr)
}
const (
@ -92,3 +110,17 @@ func (p *functionConfig) GetTextEmbeddingProviderConfig(providerName string) map
}
return matchedParam
}
// GetRerankModelProviders collects the entries of the RerankModelProviders
// group whose key begins with providerName followed by a dot. The returned map
// is keyed by the remainder of each key after that prefix; it is empty (never
// nil) when nothing matches.
func (p *functionConfig) GetRerankModelProviders(providerName string) map[string]string {
	providerPrefix := providerName + "."
	settings := map[string]string{}
	for key, value := range p.RerankModelProviders.GetValue() {
		if !strings.HasPrefix(key, providerPrefix) {
			continue
		}
		// The prefix is known to be present, so slicing it off is equivalent
		// to trimming it.
		settings[key[len(providerPrefix):]] = value
	}
	return settings
}

View File

@ -925,7 +925,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
)
vectors_to_search = rng.random((1, dim))
error = {ct.err_code: 65535,
ct.err_msg: f"Decay function only supoorts single input, but gets [[reranker_field id]] input"}
ct.err_msg: f"Decay function only supports single input, but gets [[reranker_field id]] input"}
self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
check_task=CheckTasks.err_res, check_items=error)
@ -1053,7 +1053,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
)
vectors_to_search = rng.random((1, dim))
error = {ct.err_code: 65535,
ct.err_msg: f"Unsupported rerank function: [1] , list of supported [decay]"}
ct.err_msg: f"Unsupported rerank function: [1]"}
self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
check_task=CheckTasks.err_res, check_items=error)
@ -1098,7 +1098,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
)
vectors_to_search = rng.random((1, dim))
error = {ct.err_code: 65535,
ct.err_msg: f"Unsupported rerank function: [{not_supported_reranker}] , list of supported [decay]"}
ct.err_msg: f"Unsupported rerank function: [{not_supported_reranker}]"}
self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
check_task=CheckTasks.err_res, check_items=error)
@ -3997,4 +3997,4 @@ class TestMilvusClientSearchRerankValid(TestMilvusClientV2Base):
"nq": len(vectors_to_search),
"pk_name": default_primary_key_field_name,
"limit": default_limit}
)
)