mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
https://github.com/milvus-io/milvus/issues/35856 Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
379 lines
11 KiB
Go
379 lines
11 KiB
Go
/*
|
|
* # Licensed to the LF AI & Data foundation under one
|
|
* # or more contributor license agreements. See the NOTICE file
|
|
* # distributed with this work for additional information
|
|
* # regarding copyright ownership. The ASF licenses this file
|
|
* # to you under the Apache License, Version 2.0 (the
|
|
* # "License"); you may not use this file except in compliance
|
|
* # with the License. You may obtain a copy of the License at
|
|
* #
|
|
* # http://www.apache.org/licenses/LICENSE-2.0
|
|
* #
|
|
* # Unless required by applicable law or agreed to in writing, software
|
|
* # distributed under the License is distributed on an "AS IS" BASIS,
|
|
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* # See the License for the specific language governing permissions and
|
|
* # limitations under the License.
|
|
*/
|
|
|
|
package rerank
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/internal/util/function"
|
|
"github.com/milvus-io/milvus/internal/util/function/models/utils"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
|
)
|
|
|
|
// Provider selection: the "provider" param key and the provider names it may
// take (compared after lower-casing the param value).
const (
	providerParamName string = "provider"
	vllmProviderName  string = "vllm"
	teiProviderName   string = "tei"
)

// Other rerank function param keys.
const (
	// queryKeyName holds a JSON array of query strings, one per search request.
	queryKeyName string = "queries"
	// maxBatchKeyName caps how many documents are packed into one rerank request.
	maxBatchKeyName string = "max_batch"
)
|
|
|
|
// modelProvider is implemented by each rerank backend.
type modelProvider interface {
	// rerank scores docs against query, returning one score per doc in docs order.
	rerank(ctx context.Context, query string, docs []string) ([]float32, error)
	// getURL reports the endpoint the provider sends requests to.
	getURL() string
}
|
|
|
|
// baseModel holds the HTTP plumbing shared by the rerank providers: where to
// send requests, how to batch documents, and how to decode replies.
type baseModel struct {
	// url is the full rerank endpoint that requests are POSTed to.
	url string
	// maxBatch is the maximum number of documents placed in one request.
	maxBatch int

	// queryKey and docKey name the JSON fields carrying the query and the
	// documents in each request body.
	queryKey, docKey string

	// parseScores decodes one response body into per-document scores.
	parseScores func([]byte) ([]float32, error)
}

// getURL returns the endpoint this model posts rerank requests to.
func (base *baseModel) getURL() string { return base.url }
|
|
|
|
func (base *baseModel) rerank(ctx context.Context, query string, docs []string) ([]float32, error) {
|
|
requestBodies, err := genRerankRequestBody(query, docs, base.maxBatch, base.queryKey, base.docKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
scores := []float32{}
|
|
for _, requestBody := range requestBodies {
|
|
rerankResp, err := base.callService(ctx, requestBody, 30)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Call rerank model failed: %v\n", err)
|
|
}
|
|
scores = append(scores, rerankResp...)
|
|
}
|
|
|
|
if len(scores) != len(docs) {
|
|
return nil, fmt.Errorf("Call Rerank service failed, %d docs but got %d scores", len(docs), len(scores))
|
|
}
|
|
return scores, nil
|
|
}
|
|
|
|
func (base *baseModel) callService(ctx context.Context, requestBody []byte, timeoutSec int64) ([]float32, error) {
|
|
ctx, cancel := context.WithTimeout(ctx, time.Duration(timeoutSec)*time.Second)
|
|
defer cancel()
|
|
headers := map[string]string{
|
|
"Content-Type": "application/json",
|
|
}
|
|
body, err := utils.RetrySend(ctx, requestBody, http.MethodPost, base.url, headers, 3)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return base.parseScores(body)
|
|
}
|
|
|
|
// JSON shapes for the vLLM rerank endpoint (newVllmProvider posts to
// "/v2/rerank" and decodes replies into vllmRerankResponse).
type (
	// vllmRerankRequest is the request body shape.
	// NOTE(review): the visible code builds request bodies from a map in
	// genRerankRequestBody rather than from this type.
	vllmRerankRequest struct {
		Query     string   `json:"query"`
		Documents []string `json:"documents"`
	}

	// vllmRerankResponse is the decoded response body.
	vllmRerankResponse struct {
		ID      string       `json:"id"`
		Model   string       `json:"model"`
		Usage   vllmUsage    `json:"usage"`
		Results []vllmResult `json:"results"`
	}

	// vllmUsage reports token accounting for the request.
	vllmUsage struct {
		TotalTokens int `json:"total_tokens"`
	}

	// vllmResult is the score for one document. Index is used to restore
	// request order before scores are extracted.
	vllmResult struct {
		Index          int          `json:"index"`
		Document       vllmDocument `json:"document"`
		RelevanceScore float32      `json:"relevance_score"`
	}

	// vllmDocument echoes a document's text.
	vllmDocument struct {
		Text string `json:"text"`
	}
)
|
|
|
|
type vllmProvider struct {
|
|
baseModel
|
|
}
|
|
|
|
func newVllmProvider(params []*commonpb.KeyValuePair, conf map[string]string) (modelProvider, error) {
|
|
if !isEnable(conf, function.EnableVllmEnvStr) {
|
|
return nil, fmt.Errorf("Vllm rerank is disabled")
|
|
}
|
|
endpoint, maxBatch, err := parseParams(params)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
base, _ := url.Parse(endpoint)
|
|
base.Path = "/v2/rerank"
|
|
model := baseModel{
|
|
url: base.String(),
|
|
maxBatch: maxBatch,
|
|
queryKey: "query",
|
|
docKey: "documents",
|
|
parseScores: func(body []byte) ([]float32, error) {
|
|
var rerankResp vllmRerankResponse
|
|
if err := json.Unmarshal(body, &rerankResp); err != nil {
|
|
return nil, fmt.Errorf("Rerank error, parsing vllm response failed: %v", err)
|
|
}
|
|
|
|
sort.Slice(rerankResp.Results, func(i, j int) bool {
|
|
return rerankResp.Results[i].Index < rerankResp.Results[j].Index
|
|
})
|
|
|
|
scores := make([]float32, 0, len(rerankResp.Results))
|
|
for _, result := range rerankResp.Results {
|
|
scores = append(scores, result.RelevanceScore)
|
|
}
|
|
|
|
return scores, nil
|
|
},
|
|
}
|
|
return &vllmProvider{baseModel: model}, nil
|
|
}
|
|
|
|
type teiProvider struct {
|
|
baseModel
|
|
}
|
|
|
|
// TEIResponse is one element of the JSON array returned by the TEI "/rerank"
// endpoint, as decoded by newTeiProvider.
type TEIResponse struct {
	Index int     `json:"index"` // document position; used to restore request order
	Score float32 `json:"score"` // relevance score for that document
}
|
|
|
|
func newTeiProvider(params []*commonpb.KeyValuePair, conf map[string]string) (modelProvider, error) {
|
|
if !isEnable(conf, function.EnableTeiEnvStr) {
|
|
return nil, fmt.Errorf("TEI rerank is disabled")
|
|
}
|
|
endpoint, maxBatch, err := parseParams(params)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
base, _ := url.Parse(endpoint)
|
|
base.Path = "/rerank"
|
|
model := baseModel{
|
|
url: base.String(),
|
|
maxBatch: maxBatch,
|
|
queryKey: "query",
|
|
docKey: "texts",
|
|
parseScores: func(body []byte) ([]float32, error) {
|
|
var results []TEIResponse
|
|
if err := json.Unmarshal(body, &results); err != nil {
|
|
return nil, fmt.Errorf("Rerank error, parsing TEI response failed: %v", err)
|
|
}
|
|
sort.Slice(results, func(i, j int) bool {
|
|
return results[i].Index < results[j].Index
|
|
})
|
|
scores := make([]float32, len(results))
|
|
for i, result := range results {
|
|
scores[i] = result.Score
|
|
}
|
|
return scores, nil
|
|
},
|
|
}
|
|
return &teiProvider{baseModel: model}, nil
|
|
}
|
|
|
|
// isEnable reports whether a provider may be used.
//
// An "enable" entry in conf (from milvus.yaml) takes precedence and enables
// the provider only when it equals "true", ignoring case. Without that entry,
// the environment variable envKey is consulted. The provider is enabled unless
// that variable equals "false" (ignoring case), so an unset variable enables it.
func isEnable(conf map[string]string, envKey string) bool {
	if configured, found := conf["enable"]; found {
		return strings.ToLower(configured) == "true"
	}
	envValue := strings.ToLower(os.Getenv(envKey))
	return envValue != "false"
}
|
|
|
|
func parseParams(params []*commonpb.KeyValuePair) (string, int, error) {
|
|
endpoint := ""
|
|
maxBatch := 32
|
|
for _, param := range params {
|
|
switch strings.ToLower(param.Key) {
|
|
case function.EndpointParamKey:
|
|
base, err := url.Parse(param.Value)
|
|
if err != nil {
|
|
return "", 0, err
|
|
}
|
|
if base.Scheme != "http" && base.Scheme != "https" {
|
|
return "", 0, fmt.Errorf("Rerank endpoint: [%s] is not a valid http/https link", param.Value)
|
|
}
|
|
if base.Host == "" {
|
|
return "", 0, fmt.Errorf("Rerank endpoint: [%s] is not a valid http/https link", param.Value)
|
|
}
|
|
endpoint = base.String()
|
|
case maxBatchKeyName:
|
|
if batch, err := strconv.ParseInt(param.Value, 10, 64); err != nil {
|
|
return "", 0, fmt.Errorf("Rerank params error, maxBatch: %s is not a number", param.Value)
|
|
} else {
|
|
maxBatch = int(batch)
|
|
}
|
|
}
|
|
}
|
|
if endpoint == "" {
|
|
return "", 0, fmt.Errorf("Rerank function lost params endpoint")
|
|
}
|
|
if maxBatch <= 0 {
|
|
return "", 0, fmt.Errorf("Rerank function params max_batch must > 0, but got %d", maxBatch)
|
|
}
|
|
return endpoint, maxBatch, nil
|
|
}
|
|
|
|
// genRerankRequestBody splits documents into consecutive chunks of at most
// maxSize and JSON-encodes one request body per chunk, of the form
// {queryKey: query, docKey: chunk}. Chunks preserve document order; an empty
// documents slice yields no bodies.
//
// It returns an error if maxSize is not positive (the chunking loop could not
// advance) or if JSON encoding fails.
func genRerankRequestBody(query string, documents []string, maxSize int, queryKey string, docKey string) ([][]byte, error) {
	if maxSize <= 0 {
		return nil, fmt.Errorf("Create model rerank request failed, max batch size must be > 0, but got %d", maxSize)
	}

	// ceil(len/maxSize) without the overflow of len+maxSize-1 for huge maxSize.
	numBatches := len(documents) / maxSize
	if len(documents)%maxSize != 0 {
		numBatches++
	}
	requestBodies := make([][]byte, 0, numBatches)

	for start := 0; start < len(documents); start += maxSize {
		// Computed as a remaining-length comparison so start+maxSize is only
		// formed when it is known not to exceed len(documents).
		end := len(documents)
		if end-start > maxSize {
			end = start + maxSize
		}
		requestBody := map[string]any{
			queryKey: query,
			docKey:   documents[start:end],
		}
		jsonData, err := json.Marshal(requestBody)
		if err != nil {
			return nil, fmt.Errorf("Create model rerank request failed, err: %w", err)
		}
		requestBodies = append(requestBodies, jsonData)
	}
	return requestBodies, nil
}
|
|
|
|
func newProvider(params []*commonpb.KeyValuePair) (modelProvider, error) {
|
|
for _, param := range params {
|
|
if strings.ToLower(param.Key) == providerParamName {
|
|
provider := strings.ToLower(param.Value)
|
|
conf := paramtable.Get().FunctionCfg.GetRerankModelProviders(provider)
|
|
switch provider {
|
|
case vllmProviderName:
|
|
return newVllmProvider(params, conf)
|
|
case teiProviderName:
|
|
return newTeiProvider(params, conf)
|
|
default:
|
|
return nil, fmt.Errorf("Unknow rerank provider:%s", param.Value)
|
|
}
|
|
}
|
|
}
|
|
return nil, fmt.Errorf("Lost rerank params:%s ", providerParamName)
|
|
}
|
|
|
|
type ModelFunction[T PKType] struct {
|
|
RerankBase
|
|
|
|
provider modelProvider
|
|
queries []string
|
|
}
|
|
|
|
func newModelFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
|
|
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(base.GetInputFieldNames()) != 1 {
|
|
return nil, fmt.Errorf("Rerank model only supports single input, but gets [%s] input", base.GetInputFieldNames())
|
|
}
|
|
|
|
if base.GetInputFieldTypes()[0] != schemapb.DataType_VarChar {
|
|
return nil, fmt.Errorf("Rerank model only support varchar, bug got [%s]", base.GetInputFieldTypes()[0].String())
|
|
}
|
|
|
|
provider, err := newProvider(funcSchema.Params)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
queries := []string{}
|
|
for _, param := range funcSchema.Params {
|
|
if param.Key == queryKeyName {
|
|
if err := json.Unmarshal([]byte(param.Value), &queries); err != nil {
|
|
return nil, fmt.Errorf("Parse rerank params [queries] failed, err: %v", err)
|
|
}
|
|
}
|
|
}
|
|
if len(queries) == 0 {
|
|
return nil, fmt.Errorf("Rerank function lost params queries")
|
|
}
|
|
|
|
if base.pkType == schemapb.DataType_Int64 {
|
|
return &ModelFunction[int64]{RerankBase: *base, provider: provider, queries: queries}, nil
|
|
} else {
|
|
return &ModelFunction[string]{RerankBase: *base, provider: provider, queries: queries}, nil
|
|
}
|
|
}
|
|
|
|
func (model *ModelFunction[T]) processOneSearchData(ctx context.Context, searchParams *SearchParams, query string, cols []*columns) (*IDScores[T], error) {
|
|
uniqueData := make(map[T]string)
|
|
for _, col := range cols {
|
|
texts := col.data[0].([]string)
|
|
ids := col.ids.([]T)
|
|
for idx, id := range ids {
|
|
if _, ok := uniqueData[id]; !ok {
|
|
uniqueData[id] = texts[idx]
|
|
}
|
|
}
|
|
}
|
|
ids := make([]T, 0, len(uniqueData))
|
|
texts := make([]string, 0, len(uniqueData))
|
|
for id, text := range uniqueData {
|
|
ids = append(ids, id)
|
|
texts = append(texts, text)
|
|
}
|
|
scores, err := model.provider.rerank(ctx, query, texts)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rerankScores := map[T]float32{}
|
|
for idx, id := range ids {
|
|
rerankScores[id] = scores[idx]
|
|
}
|
|
return newIDScores(rerankScores, searchParams), nil
|
|
}
|
|
|
|
func (model *ModelFunction[T]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) {
|
|
if len(model.queries) != int(searchParams.nq) {
|
|
return nil, fmt.Errorf("nq must equal to queries size, but got nq [%d], queries size [%d], queries: [%v]", searchParams.nq, len(model.queries), model.queries)
|
|
}
|
|
outputs := newRerankOutputs(searchParams)
|
|
for idx, cols := range inputs.data {
|
|
idScore, err := model.processOneSearchData(ctx, searchParams, model.queries[idx], cols)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
appendResult(outputs, idScore.ids, idScore.scores)
|
|
}
|
|
return outputs, nil
|
|
}
|