milvus/internal/util/function/rerank/model_function.go
junjiejiangjjj 4202c775ba
feat: Support vllm and tei rerank (#41947)
https://github.com/milvus-io/milvus/issues/35856

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
2025-05-28 19:18:28 +08:00

379 lines
11 KiB
Go

/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package rerank
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"os"
"sort"
"strconv"
"strings"
"time"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/function"
"github.com/milvus-io/milvus/internal/util/function/models/utils"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
)
// Parameter keys and provider names recognized by the model rerank function.
const (
	// providerParamName is the function param key whose value selects the
	// rerank backend (see newProvider).
	providerParamName string = "provider"
	// vllmProviderName selects the vLLM backend (newVllmProvider).
	vllmProviderName string = "vllm"
	// teiProviderName selects the TEI backend (newTeiProvider).
	teiProviderName string = "tei"
	// queryKeyName is the function param key holding a JSON array of query
	// strings; Process requires one per search request (nq).
	queryKeyName string = "queries"
	// maxBatchKeyName is the function param key limiting how many documents
	// go into a single rerank request (see parseParams).
	maxBatchKeyName string = "max_batch"
)
// modelProvider is a remote rerank backend used by ModelFunction.
type modelProvider interface {
	// rerank scores each document against the query. Callers index the
	// result positionally, so it must hold one score per document in the
	// same order as the input documents.
	rerank(context.Context, string, []string) ([]float32, error)
	// getURL returns the endpoint URL the provider sends requests to.
	getURL() string
}
// baseModel implements the HTTP request/response cycle shared by the vLLM and
// TEI providers. The per-provider differences are the endpoint path, the JSON
// key names in the request body, and the response parser.
type baseModel struct {
	// url is the full rerank endpoint, including the provider-specific path.
	url string
	// maxBatch is the maximum number of documents sent in one request.
	maxBatch int
	// queryKey and docKey are the JSON field names used for the query and the
	// document list in each request body.
	queryKey string
	docKey   string
	// parseScores decodes one response body into scores in request order.
	parseScores func([]byte) ([]float32, error)
}
// getURL reports the rerank endpoint this model posts requests to.
func (base *baseModel) getURL() string { return base.url }
func (base *baseModel) rerank(ctx context.Context, query string, docs []string) ([]float32, error) {
requestBodies, err := genRerankRequestBody(query, docs, base.maxBatch, base.queryKey, base.docKey)
if err != nil {
return nil, err
}
scores := []float32{}
for _, requestBody := range requestBodies {
rerankResp, err := base.callService(ctx, requestBody, 30)
if err != nil {
return nil, fmt.Errorf("Call rerank model failed: %v\n", err)
}
scores = append(scores, rerankResp...)
}
if len(scores) != len(docs) {
return nil, fmt.Errorf("Call Rerank service failed, %d docs but got %d scores", len(docs), len(scores))
}
return scores, nil
}
// callService POSTs one JSON request body to base.url via utils.RetrySend
// (the trailing 3 is presumably the retry budget — see that helper) under a
// deadline of timeoutSec seconds derived from ctx, then decodes the response
// body with base.parseScores.
func (base *baseModel) callService(ctx context.Context, requestBody []byte, timeoutSec int64) ([]float32, error) {
	deadline := time.Duration(timeoutSec) * time.Second
	callCtx, cancel := context.WithTimeout(ctx, deadline)
	defer cancel()

	body, err := utils.RetrySend(callCtx, requestBody, http.MethodPost, base.url,
		map[string]string{"Content-Type": "application/json"}, 3)
	if err != nil {
		return nil, err
	}
	return base.parseScores(body)
}
// vllmRerankRequest mirrors the JSON body sent to the vLLM rerank endpoint.
// NOTE(review): request bodies are built as maps in genRerankRequestBody;
// this type is not referenced in this file.
type vllmRerankRequest struct {
	Query     string   `json:"query"`
	Documents []string `json:"documents"`
}

// vllmRerankResponse is the decoded vLLM rerank response. Only Results is
// read by the parser installed in newVllmProvider.
type vllmRerankResponse struct {
	ID      string       `json:"id"`
	Model   string       `json:"model"`
	Usage   vllmUsage    `json:"usage"`
	Results []vllmResult `json:"results"`
}

// vllmUsage holds the token count reported in a vLLM response.
type vllmUsage struct {
	TotalTokens int `json:"total_tokens"`
}

// vllmResult is one scored document. The parser sorts on Index to restore
// request order (assumes Index is the position within the request's
// document list — confirm against the vLLM API).
type vllmResult struct {
	Index          int          `json:"index"`
	Document       vllmDocument `json:"document"`
	RelevanceScore float32      `json:"relevance_score"`
}

// vllmDocument holds the document text as returned in a vLLM result.
type vllmDocument struct {
	Text string `json:"text"`
}

// vllmProvider reranks through a vLLM server's /v2/rerank endpoint.
type vllmProvider struct {
	baseModel
}
// newVllmProvider builds a rerank provider backed by a vLLM server.
//
// It fails when vLLM reranking is disabled by configuration (see isEnable) or
// when params lack a valid endpoint / max_batch (see parseParams). Any path on
// the configured endpoint is replaced with /v2/rerank, and documents are sent
// under the "documents" key. Responses are decoded as vllmRerankResponse and
// returned as scores ordered by each result's index.
func newVllmProvider(params []*commonpb.KeyValuePair, conf map[string]string) (modelProvider, error) {
	if !isEnable(conf, function.EnableVllmEnvStr) {
		return nil, fmt.Errorf("Vllm rerank is disabled")
	}
	endpoint, maxBatch, err := parseParams(params)
	if err != nil {
		return nil, err
	}
	// parseParams already validated the endpoint, but never discard this
	// error: a nil *url.URL would panic on the next line.
	base, err := url.Parse(endpoint)
	if err != nil {
		return nil, fmt.Errorf("Rerank endpoint: [%s] is not a valid http/https link: %w", endpoint, err)
	}
	base.Path = "/v2/rerank"
	model := baseModel{
		url:      base.String(),
		maxBatch: maxBatch,
		queryKey: "query",
		docKey:   "documents",
		parseScores: func(body []byte) ([]float32, error) {
			var rerankResp vllmRerankResponse
			if err := json.Unmarshal(body, &rerankResp); err != nil {
				return nil, fmt.Errorf("Rerank error, parsing vllm response failed: %w", err)
			}
			// Results are not assumed to be in request order; sort them back.
			sort.Slice(rerankResp.Results, func(i, j int) bool {
				return rerankResp.Results[i].Index < rerankResp.Results[j].Index
			})
			scores := make([]float32, 0, len(rerankResp.Results))
			for i, result := range rerankResp.Results {
				// After sorting, indices must be exactly 0..n-1; a gap or a
				// duplicate would silently attach scores to the wrong documents.
				if result.Index != i {
					return nil, fmt.Errorf("Rerank error, vllm response has unexpected result index %d at position %d", result.Index, i)
				}
				scores = append(scores, result.RelevanceScore)
			}
			return scores, nil
		},
	}
	return &vllmProvider{baseModel: model}, nil
}
// teiProvider reranks through a TEI server's /rerank endpoint.
type teiProvider struct {
	baseModel
}

// TEIResponse is one element of the JSON array returned by the TEI rerank
// endpoint. The parser sorts on Index to restore request order (assumes Index
// is the position within the request's "texts" list — confirm against TEI).
type TEIResponse struct {
	Index int     `json:"index"`
	Score float32 `json:"score"`
}
// newTeiProvider builds a rerank provider backed by a TEI server.
//
// It fails when TEI reranking is disabled by configuration (see isEnable) or
// when params lack a valid endpoint / max_batch (see parseParams). Any path on
// the configured endpoint is replaced with /rerank, and documents are sent
// under the "texts" key. Responses are decoded as a JSON array of TEIResponse
// and returned as scores ordered by each element's index.
func newTeiProvider(params []*commonpb.KeyValuePair, conf map[string]string) (modelProvider, error) {
	if !isEnable(conf, function.EnableTeiEnvStr) {
		return nil, fmt.Errorf("TEI rerank is disabled")
	}
	endpoint, maxBatch, err := parseParams(params)
	if err != nil {
		return nil, err
	}
	// parseParams already validated the endpoint, but never discard this
	// error: a nil *url.URL would panic on the next line.
	base, err := url.Parse(endpoint)
	if err != nil {
		return nil, fmt.Errorf("Rerank endpoint: [%s] is not a valid http/https link: %w", endpoint, err)
	}
	base.Path = "/rerank"
	model := baseModel{
		url:      base.String(),
		maxBatch: maxBatch,
		queryKey: "query",
		docKey:   "texts",
		parseScores: func(body []byte) ([]float32, error) {
			var results []TEIResponse
			if err := json.Unmarshal(body, &results); err != nil {
				return nil, fmt.Errorf("Rerank error, parsing TEI response failed: %w", err)
			}
			// Results are not assumed to be in request order; sort them back.
			sort.Slice(results, func(i, j int) bool {
				return results[i].Index < results[j].Index
			})
			scores := make([]float32, len(results))
			for i, result := range results {
				// After sorting, indices must be exactly 0..n-1; a gap or a
				// duplicate would silently attach scores to the wrong documents.
				if result.Index != i {
					return nil, fmt.Errorf("Rerank error, TEI response has unexpected result index %d at position %d", result.Index, i)
				}
				scores[i] = result.Score
			}
			return scores, nil
		},
	}
	return &teiProvider{baseModel: model}, nil
}
// isEnable reports whether a rerank provider is switched on.
//
// An "enable" entry in conf (milvus.yaml) takes precedence and enables the
// provider only when it lowercases to "true". Without that entry the provider
// is on unless the environment variable envKey lowercases to "false"; an
// unset variable therefore counts as enabled.
func isEnable(conf map[string]string, envKey string) bool {
	if value, found := conf["enable"]; found {
		return strings.ToLower(value) == "true"
	}
	return strings.ToLower(os.Getenv(envKey)) != "false"
}
// parseParams extracts the rerank endpoint and batch size from function params.
//
// Keys are lowercased before matching. The endpoint must be an http/https URL
// with a host and is returned re-serialized by net/url. max_batch defaults to
// 32 and must be a positive integer that fits in an int. An error is returned
// when the endpoint is missing or any value is invalid.
func parseParams(params []*commonpb.KeyValuePair) (string, int, error) {
	endpoint := ""
	maxBatch := 32
	for _, param := range params {
		switch strings.ToLower(param.Key) {
		case function.EndpointParamKey:
			base, err := url.Parse(param.Value)
			if err != nil {
				return "", 0, fmt.Errorf("Rerank endpoint: [%s] is not a valid http/https link: %w", param.Value, err)
			}
			if (base.Scheme != "http" && base.Scheme != "https") || base.Host == "" {
				return "", 0, fmt.Errorf("Rerank endpoint: [%s] is not a valid http/https link", param.Value)
			}
			endpoint = base.String()
		case maxBatchKeyName:
			// Atoi range-checks against the platform int size; ParseInt(..., 64)
			// followed by int() would silently truncate on 32-bit builds.
			batch, err := strconv.Atoi(param.Value)
			if err != nil {
				return "", 0, fmt.Errorf("Rerank params error, maxBatch: %s is not a number", param.Value)
			}
			maxBatch = batch
		}
	}
	if endpoint == "" {
		return "", 0, fmt.Errorf("Rerank function lost params endpoint")
	}
	if maxBatch <= 0 {
		return "", 0, fmt.Errorf("Rerank function params max_batch must > 0, but got %d", maxBatch)
	}
	return endpoint, maxBatch, nil
}
// genRerankRequestBody splits documents into consecutive chunks of at most
// maxSize entries and JSON-encodes one request body per chunk, shaped as
// {queryKey: query, docKey: chunk}. Chunks follow the order of documents, so
// concatenating per-chunk results restores document order. An empty
// documents slice yields no bodies.
//
// maxSize must be positive. Otherwise an error is returned: a zero step would
// loop forever and a negative one would panic when slicing.
func genRerankRequestBody(query string, documents []string, maxSize int, queryKey string, docKey string) ([][]byte, error) {
	if maxSize <= 0 {
		return nil, fmt.Errorf("Create model rerank request failed, batch size must be > 0, but got %d", maxSize)
	}
	// Ceil-divide without computing len+maxSize-1, which can overflow for a
	// very large maxSize.
	numBatches := len(documents) / maxSize
	if len(documents)%maxSize != 0 {
		numBatches++
	}
	requestBodies := make([][]byte, 0, numBatches)
	for start := 0; start < len(documents); start += maxSize {
		end := len(documents)
		if end-start > maxSize {
			end = start + maxSize
		}
		requestBody := map[string]any{
			queryKey: query,
			docKey:   documents[start:end],
		}
		jsonData, err := json.Marshal(requestBody)
		if err != nil {
			return nil, fmt.Errorf("Create model rerank request failed, err: %w", err)
		}
		requestBodies = append(requestBodies, jsonData)
	}
	return requestBodies, nil
}
// newProvider constructs the rerank backend named by the first "provider"
// param. The key and value are lowercased before matching. The backend
// receives the provider's section of the rerank model configuration from
// paramtable. It returns an error for an unknown provider, or when no
// provider param is present.
func newProvider(params []*commonpb.KeyValuePair) (modelProvider, error) {
	for _, param := range params {
		if strings.ToLower(param.Key) != providerParamName {
			continue
		}
		provider := strings.ToLower(param.Value)
		conf := paramtable.Get().FunctionCfg.GetRerankModelProviders(provider)
		switch provider {
		case vllmProviderName:
			return newVllmProvider(params, conf)
		case teiProviderName:
			return newTeiProvider(params, conf)
		default:
			return nil, fmt.Errorf("Unknown rerank provider:%s", param.Value)
		}
	}
	return nil, fmt.Errorf("Lost rerank params:%s", providerParamName)
}
// ModelFunction is the Reranker returned by newModelFunction. It rescores
// search hits with a remote rerank model. T is the collection's primary-key
// type (int64 for Int64 primary keys, string otherwise).
type ModelFunction[T PKType] struct {
	RerankBase
	// provider is the remote backend that scores (query, document) pairs.
	provider modelProvider
	// queries holds one query string per search request; Process pairs
	// queries[i] with the i-th entry of its input and requires
	// len(queries) == nq.
	queries []string
}
// newModelFunction builds a model-based Reranker from the function schema.
//
// The function must have exactly one input field of type VarChar (the text
// sent to the model). It also needs a "provider" param selecting the backend
// (see newProvider) and a "queries" param holding a non-empty JSON array of
// strings. The result is ModelFunction[int64] for Int64 primary keys and
// ModelFunction[string] otherwise.
func newModelFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
	// NOTE(review): decayFunctionName is passed for the model reranker, which
	// looks like a copy-paste from the decay function; confirm the intended
	// name before changing it.
	base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false)
	if err != nil {
		return nil, err
	}
	inputNames := base.GetInputFieldNames()
	if len(inputNames) != 1 {
		return nil, fmt.Errorf("Rerank model only supports single input, but gets [%s] input", inputNames)
	}
	if inputType := base.GetInputFieldTypes()[0]; inputType != schemapb.DataType_VarChar {
		return nil, fmt.Errorf("Rerank model only support varchar, but got [%s]", inputType.String())
	}
	provider, err := newProvider(funcSchema.Params)
	if err != nil {
		return nil, err
	}
	queries := []string{}
	for _, param := range funcSchema.Params {
		// Lowercase the key like the provider/endpoint/max_batch params do.
		if strings.ToLower(param.Key) == queryKeyName {
			if err := json.Unmarshal([]byte(param.Value), &queries); err != nil {
				return nil, fmt.Errorf("Parse rerank params [queries] failed, err: %w", err)
			}
		}
	}
	if len(queries) == 0 {
		return nil, fmt.Errorf("Rerank function lost params queries")
	}
	if base.pkType == schemapb.DataType_Int64 {
		return &ModelFunction[int64]{RerankBase: *base, provider: provider, queries: queries}, nil
	}
	return &ModelFunction[string]{RerankBase: *base, provider: provider, queries: queries}, nil
}
// processOneSearchData reranks the hits of a single search request against
// query.
//
// For each column in cols, col.data[0] is read as the VarChar input texts and
// col.ids as the matching primary keys. Hits are deduplicated by primary key,
// keeping the first text seen for each key. They are sent to the provider in
// first-seen order, so identical searches issue identical model requests
// (map iteration order would otherwise shuffle documents and batch
// boundaries). The model scores, keyed by primary key, are passed to
// newIDScores together with searchParams.
func (model *ModelFunction[T]) processOneSearchData(ctx context.Context, searchParams *SearchParams, query string, cols []*columns) (*IDScores[T], error) {
	seen := make(map[T]struct{})
	ids := []T{}
	texts := []string{}
	for _, col := range cols {
		if len(col.data) == 0 {
			return nil, fmt.Errorf("Rerank model input column has no data")
		}
		colTexts, ok := col.data[0].([]string)
		if !ok {
			return nil, fmt.Errorf("Rerank model input must be varchar, but got %T", col.data[0])
		}
		colIDs, ok := col.ids.([]T)
		if !ok {
			return nil, fmt.Errorf("Rerank model got unexpected primary key type %T", col.ids)
		}
		if len(colTexts) < len(colIDs) {
			return nil, fmt.Errorf("Rerank model input has %d texts for %d ids", len(colTexts), len(colIDs))
		}
		for idx, id := range colIDs {
			if _, dup := seen[id]; dup {
				continue
			}
			seen[id] = struct{}{}
			ids = append(ids, id)
			texts = append(texts, colTexts[idx])
		}
	}

	rerankScores := make(map[T]float32, len(ids))
	// Skip the provider round-trip when there is nothing to score.
	if len(texts) > 0 {
		scores, err := model.provider.rerank(ctx, query, texts)
		if err != nil {
			return nil, err
		}
		// Scores are matched to ids by position; guard so a provider that
		// returns the wrong count cannot misalign scores or panic.
		if len(scores) != len(ids) {
			return nil, fmt.Errorf("Rerank model returned %d scores for %d documents", len(scores), len(ids))
		}
		for idx, id := range ids {
			rerankScores[id] = scores[idx]
		}
	}
	return newIDScores(rerankScores, searchParams), nil
}
// Process reranks every search request in inputs. The i-th entry of
// inputs.data is paired with the i-th configured query. Each rescored result
// is appended to a rerankOutputs created from searchParams. Process fails if
// the number of configured queries differs from searchParams.nq, or if any
// per-request rerank fails.
func (model *ModelFunction[T]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) {
	if numQueries := len(model.queries); numQueries != int(searchParams.nq) {
		return nil, fmt.Errorf("nq must equal to queries size, but got nq [%d], queries size [%d], queries: [%v]", searchParams.nq, numQueries, model.queries)
	}

	outputs := newRerankOutputs(searchParams)
	for i := range inputs.data {
		query := model.queries[i]
		rescored, err := model.processOneSearchData(ctx, searchParams, query, inputs.data[i])
		if err != nil {
			return nil, err
		}
		appendResult(outputs, rescored.ids, rescored.scores)
	}
	return outputs, nil
}