mirror of
https://gitee.com/milvus-io/milvus.git
synced 2026-01-07 19:31:51 +08:00
feat: Support vllm and tei rerank (#41947)
https://github.com/milvus-io/milvus/issues/35856 Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
This commit is contained in:
parent
14563ad2b3
commit
4202c775ba
@ -1327,3 +1327,10 @@ function:
|
||||
voyageai:
|
||||
credential: # The name in the credential configuration item
|
||||
url: # Your voyageai embedding url, Default is the official embedding url
|
||||
rerank:
|
||||
model:
|
||||
providers:
|
||||
tei:
|
||||
enable: true # Whether to enable TEI rerank service
|
||||
vllm:
|
||||
enable: true # Whether to enable vllm rerank service
|
||||
|
||||
@ -646,6 +646,17 @@ func mergeIDs(idsList []*schemapb.IDs) (*schemapb.IDs, int) {
|
||||
|
||||
func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, multipleMilvusResults []*milvuspb.SearchResults, searchMetrics []string) error {
|
||||
var err error
|
||||
processRerank := func(ctx context.Context, results []*milvuspb.SearchResults) (*milvuspb.SearchResults, error) {
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
|
||||
defer sp.End()
|
||||
|
||||
params := rerank.NewSearchParams(
|
||||
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
|
||||
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics,
|
||||
)
|
||||
return t.functionScore.Process(ctx, params, results)
|
||||
}
|
||||
|
||||
// The first step of hybrid search is without meta information. If rerank requires meta data, we need to do requery.
|
||||
// At this time, outputFields and rerank input_fields will be recalled.
|
||||
// If we want to save memory, we can only recall the rerank input_fields in this step, and recall the output_fields in the third step
|
||||
@ -682,12 +693,7 @@ func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, mult
|
||||
for i := 0; i < len(multipleMilvusResults); i++ {
|
||||
multipleMilvusResults[i].Results.FieldsData = fields[i]
|
||||
}
|
||||
params := rerank.NewSearchParams(
|
||||
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
|
||||
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics,
|
||||
)
|
||||
|
||||
if t.result, err = t.functionScore.Process(ctx, params, multipleMilvusResults); err != nil {
|
||||
if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil {
|
||||
return err
|
||||
}
|
||||
if fields, err := t.reorganizeRequeryResults(ctx, queryResult.GetFieldsData(), []*schemapb.IDs{t.result.Results.Ids}); err != nil {
|
||||
@ -696,11 +702,7 @@ func (t *searchTask) hybridSearchRank(ctx context.Context, span trace.Span, mult
|
||||
t.result.Results.FieldsData = fields[0]
|
||||
}
|
||||
} else {
|
||||
params := rerank.NewSearchParams(
|
||||
t.Nq, t.rankParams.limit, t.rankParams.offset, t.rankParams.roundDecimal,
|
||||
t.rankParams.groupByFieldId, t.rankParams.groupSize, t.rankParams.strictGroupSize, searchMetrics,
|
||||
)
|
||||
if t.result, err = t.functionScore.Process(ctx, params, multipleMilvusResults); err != nil {
|
||||
if t.result, err = processRerank(ctx, multipleMilvusResults); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -823,11 +825,15 @@ func (t *searchTask) searchPostProcess(ctx context.Context, span trace.Span, toR
|
||||
}
|
||||
|
||||
if t.functionScore != nil && len(result.Results.FieldsData) != 0 {
|
||||
params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(),
|
||||
t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, []string{metricType})
|
||||
// rank only returns id and score
|
||||
if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil {
|
||||
return err
|
||||
{
|
||||
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-call-rerank-function-udf")
|
||||
defer sp.End()
|
||||
params := rerank.NewSearchParams(t.Nq, t.SearchRequest.GetTopk(), t.SearchRequest.GetOffset(),
|
||||
t.queryInfos[0].RoundDecimal, t.queryInfos[0].GroupByFieldId, t.queryInfos[0].GroupSize, t.queryInfos[0].StrictGroupSize, []string{metricType})
|
||||
// rank only returns id and score
|
||||
if t.result, err = t.functionScore.Process(ctx, params, []*milvuspb.SearchResults{result}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if !t.needRequery {
|
||||
fields, err := t.reorganizeRequeryResults(ctx, result.Results.FieldsData, []*schemapb.IDs{t.result.Results.Ids})
|
||||
|
||||
@ -109,16 +109,17 @@ const (
|
||||
siliconflowAKEnvStr string = "MILVUSAI_SILICONFLOW_API_KEY"
|
||||
)
|
||||
|
||||
// TEI
|
||||
// TEI and vllm
|
||||
|
||||
const (
|
||||
ingestionPromptParamKey string = "ingestion_prompt"
|
||||
searchPromptParamKey string = "search_prompt"
|
||||
maxClientBatchSizeParamKey string = "max_client_batch_size"
|
||||
truncationDirectionParamKey string = "truncation_direction"
|
||||
endpointParamKey string = "endpoint"
|
||||
EndpointParamKey string = "endpoint"
|
||||
|
||||
enableTeiEnvStr string = "MILVUSAI_ENABLE_TEI"
|
||||
EnableTeiEnvStr string = "MILVUSAI_ENABLE_TEI"
|
||||
EnableVllmEnvStr string = "MILVUSAI_ENABLE_VLLM"
|
||||
)
|
||||
|
||||
func parseAKAndURL(credentials *credentials.Credentials, params []*commonpb.KeyValuePair, confParams map[string]string, apiKeyEnv string) (string, string, error) {
|
||||
|
||||
@ -140,7 +140,7 @@ func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
|
||||
}
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -106,7 +106,7 @@ func (c *CohereEmbedding) Embedding(modelName string, texts []string, inputType
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": fmt.Sprintf("bearer %s", c.apiKey),
|
||||
}
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -149,7 +149,7 @@ func (c *openAIBase) embedding(url string, headers map[string]string, modelName
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, url, headers, 3, 1)
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, url, headers, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -218,8 +218,8 @@ func TestEmbeddingFailed(t *testing.T) {
|
||||
func TestTimeout(t *testing.T) {
|
||||
var st int32 = 0
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// (Timeout 1s + Wait 1s) * Retry 3
|
||||
time.Sleep(6 * time.Second)
|
||||
// (Timeout 1s + 2s + 4s + Wait 1s * 3)
|
||||
time.Sleep(11 * time.Second)
|
||||
atomic.AddInt32(&st, 1)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
}))
|
||||
|
||||
@ -121,7 +121,7 @@ func (c *SiliconflowEmbedding) Embedding(modelName string, texts []string, encod
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
|
||||
}
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -93,7 +93,7 @@ func (c *TEIEmbedding) Embedding(texts []string, truncate bool, truncationDirect
|
||||
if c.apiKey != "" {
|
||||
headers["Authorization"] = fmt.Sprintf("Bearer %s", c.apiKey)
|
||||
}
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -21,6 +21,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
@ -45,7 +46,7 @@ func send(req *http.Request) ([]byte, error) {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func RetrySend(ctx context.Context, data []byte, httpMethod string, url string, headers map[string]string, maxRetries int, retryDelay int) ([]byte, error) {
|
||||
func RetrySend(ctx context.Context, data []byte, httpMethod string, url string, headers map[string]string, maxRetries int) ([]byte, error) {
|
||||
var err error
|
||||
var body []byte
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
@ -60,7 +61,9 @@ func RetrySend(ctx context.Context, data []byte, httpMethod string, url string,
|
||||
if err == nil {
|
||||
return body, nil
|
||||
}
|
||||
time.Sleep(time.Duration(retryDelay) * time.Second)
|
||||
backoffDelay := 1 << uint(i) * time.Second
|
||||
jitter := time.Duration(rand.Int63n(int64(backoffDelay / 4)))
|
||||
time.Sleep(backoffDelay + jitter)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -148,7 +148,7 @@ func (c *VertexAIEmbedding) Embedding(modelName string, texts []string, dim int6
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": fmt.Sprintf("Bearer %s", token),
|
||||
}
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -138,7 +138,7 @@ func (c *VoyageAIEmbedding) Embedding(modelName string, texts []string, dim int,
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
|
||||
}
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3, 1)
|
||||
body, err := utils.RetrySend(ctx, data, http.MethodPost, c.url, headers, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -55,34 +55,17 @@ type DecayFunction[T PKType, R int32 | int64 | float32 | float64] struct {
|
||||
}
|
||||
|
||||
func newDecayFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
|
||||
pkType := schemapb.DataType_None
|
||||
for _, field := range collSchema.Fields {
|
||||
if field.IsPrimaryKey {
|
||||
pkType = field.DataType
|
||||
}
|
||||
}
|
||||
|
||||
if pkType == schemapb.DataType_None {
|
||||
return nil, fmt.Errorf("Collection %s can not found pk field", collSchema.Name)
|
||||
}
|
||||
|
||||
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false, pkType)
|
||||
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(base.GetInputFieldNames()) != 1 {
|
||||
return nil, fmt.Errorf("Decay function only supoorts single input, but gets [%s] input", base.GetInputFieldNames())
|
||||
return nil, fmt.Errorf("Decay function only supports single input, but gets [%s] input", base.GetInputFieldNames())
|
||||
}
|
||||
|
||||
var inputType schemapb.DataType
|
||||
for _, field := range collSchema.Fields {
|
||||
if field.Name == base.GetInputFieldNames()[0] {
|
||||
inputType = field.DataType
|
||||
}
|
||||
}
|
||||
|
||||
if pkType == schemapb.DataType_Int64 {
|
||||
inputType := base.GetInputFieldTypes()[0]
|
||||
if base.pkType == schemapb.DataType_Int64 {
|
||||
switch inputType {
|
||||
case schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32:
|
||||
return newFunction[int64, int32](base, funcSchema)
|
||||
@ -160,7 +143,7 @@ func newFunction[T PKType, R int32 | int64 | float32 | float64](base *RerankBase
|
||||
}
|
||||
|
||||
if decayFunc.decay <= 0 || decayFunc.decay >= 1 {
|
||||
return nil, fmt.Errorf("Decay function param: decay must 0 < decay < 1, but got %f", decayFunc.offset)
|
||||
return nil, fmt.Errorf("Decay function param: decay must 0 < decay < 1, but got %f", decayFunc.decay)
|
||||
}
|
||||
|
||||
switch decayFunc.functionName {
|
||||
|
||||
@ -90,7 +90,7 @@ func (s *DecayFunctionSuite) TestNewDecayErrors() {
|
||||
{
|
||||
functionSchema.InputFieldNames = []string{"ts", "pk"}
|
||||
_, err := newDecayFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Decay function only supoorts single input, but gets")
|
||||
s.ErrorContains(err, "Decay function only supports single input, but gets")
|
||||
}
|
||||
|
||||
{
|
||||
|
||||
@ -32,6 +32,7 @@ import (
|
||||
|
||||
const (
|
||||
decayFunctionName string = "decay"
|
||||
modelFunctionName string = "model"
|
||||
)
|
||||
|
||||
type SearchParams struct {
|
||||
@ -92,8 +93,10 @@ func createFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.
|
||||
switch rerankerName {
|
||||
case decayFunctionName:
|
||||
rerankFunc, newRerankErr = newDecayFunction(collSchema, funcSchema)
|
||||
case modelFunctionName:
|
||||
rerankFunc, newRerankErr = newModelFunction(collSchema, funcSchema)
|
||||
default:
|
||||
return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s]", rerankerName, decayFunctionName)
|
||||
return nil, fmt.Errorf("Unsupported rerank function: [%s] , list of supported [%s,%s]", rerankerName, decayFunctionName, modelFunctionName)
|
||||
}
|
||||
|
||||
if newRerankErr != nil {
|
||||
|
||||
378
internal/util/function/rerank/model_function.go
Normal file
378
internal/util/function/rerank/model_function.go
Normal file
@ -0,0 +1,378 @@
|
||||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package rerank
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
"github.com/milvus-io/milvus/internal/util/function/models/utils"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
)
|
||||
|
||||
// Provider selection for the "model" rerank function.
const (
	// providerParamName is the function param key that names the rerank provider.
	providerParamName string = "provider"
	// vllmProviderName is the provider value that selects the vLLM provider.
	vllmProviderName string = "vllm"
	// teiProviderName is the provider value that selects the TEI provider.
	teiProviderName string = "tei"
)

// Remaining function param keys understood by the "model" rerank function.
const (
	// queryKeyName is the param key whose value is a JSON array of query strings.
	queryKeyName string = "queries"
	// maxBatchKeyName is the param key that caps how many documents go into one request.
	maxBatchKeyName string = "max_batch"
)
|
||||
|
||||
// modelProvider is the contract a rerank service client has to fulfil.
type modelProvider interface {
	// rerank scores each entry of docs against query and returns the scores.
	rerank(ctx context.Context, query string, docs []string) ([]float32, error)
	// getURL returns the URL the provider sends its requests to.
	getURL() string
}
|
||||
|
||||
// baseModel carries the settings shared by the HTTP rerank providers.
type baseModel struct {
	// url is the address rerank requests are posted to.
	url string
	// maxBatch is the largest number of documents placed in one request.
	maxBatch int

	// queryKey and docKey are the JSON field names under which the query
	// text and the document list are written into the request body.
	queryKey, docKey string

	// parseScores turns a raw response body into a list of scores.
	parseScores func(body []byte) ([]float32, error)
}

// getURL reports the address rerank requests are posted to.
func (base *baseModel) getURL() string {
	return base.url
}
|
||||
|
||||
func (base *baseModel) rerank(ctx context.Context, query string, docs []string) ([]float32, error) {
|
||||
requestBodies, err := genRerankRequestBody(query, docs, base.maxBatch, base.queryKey, base.docKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
scores := []float32{}
|
||||
for _, requestBody := range requestBodies {
|
||||
rerankResp, err := base.callService(ctx, requestBody, 30)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Call rerank model failed: %v\n", err)
|
||||
}
|
||||
scores = append(scores, rerankResp...)
|
||||
}
|
||||
|
||||
if len(scores) != len(docs) {
|
||||
return nil, fmt.Errorf("Call Rerank service failed, %d docs but got %d scores", len(docs), len(scores))
|
||||
}
|
||||
return scores, nil
|
||||
}
|
||||
|
||||
func (base *baseModel) callService(ctx context.Context, requestBody []byte, timeoutSec int64) ([]float32, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Duration(timeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
headers := map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
body, err := utils.RetrySend(ctx, requestBody, http.MethodPost, base.url, headers, 3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return base.parseScores(body)
|
||||
}
|
||||
|
||||
// JSON shapes exchanged with the vLLM rerank endpoint.
type (
	// vllmRerankRequest is a rerank request: one query and the documents to score.
	vllmRerankRequest struct {
		Query     string   `json:"query"`
		Documents []string `json:"documents"`
	}

	// vllmRerankResponse is the top-level object of a rerank response.
	vllmRerankResponse struct {
		ID      string       `json:"id"`
		Model   string       `json:"model"`
		Usage   vllmUsage    `json:"usage"`
		Results []vllmResult `json:"results"`
	}

	// vllmUsage is the "usage" object of a response.
	vllmUsage struct {
		TotalTokens int `json:"total_tokens"`
	}

	// vllmResult is one entry of the "results" array of a response.
	vllmResult struct {
		Index          int          `json:"index"`
		Document       vllmDocument `json:"document"`
		RelevanceScore float32      `json:"relevance_score"`
	}

	// vllmDocument is the "document" object attached to a result.
	vllmDocument struct {
		Text string `json:"text"`
	}
)
|
||||
|
||||
type vllmProvider struct {
|
||||
baseModel
|
||||
}
|
||||
|
||||
func newVllmProvider(params []*commonpb.KeyValuePair, conf map[string]string) (modelProvider, error) {
|
||||
if !isEnable(conf, function.EnableVllmEnvStr) {
|
||||
return nil, fmt.Errorf("Vllm rerank is disabled")
|
||||
}
|
||||
endpoint, maxBatch, err := parseParams(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
base, _ := url.Parse(endpoint)
|
||||
base.Path = "/v2/rerank"
|
||||
model := baseModel{
|
||||
url: base.String(),
|
||||
maxBatch: maxBatch,
|
||||
queryKey: "query",
|
||||
docKey: "documents",
|
||||
parseScores: func(body []byte) ([]float32, error) {
|
||||
var rerankResp vllmRerankResponse
|
||||
if err := json.Unmarshal(body, &rerankResp); err != nil {
|
||||
return nil, fmt.Errorf("Rerank error, parsing vllm response failed: %v", err)
|
||||
}
|
||||
|
||||
sort.Slice(rerankResp.Results, func(i, j int) bool {
|
||||
return rerankResp.Results[i].Index < rerankResp.Results[j].Index
|
||||
})
|
||||
|
||||
scores := make([]float32, 0, len(rerankResp.Results))
|
||||
for _, result := range rerankResp.Results {
|
||||
scores = append(scores, result.RelevanceScore)
|
||||
}
|
||||
|
||||
return scores, nil
|
||||
},
|
||||
}
|
||||
return &vllmProvider{baseModel: model}, nil
|
||||
}
|
||||
|
||||
type teiProvider struct {
|
||||
baseModel
|
||||
}
|
||||
|
||||
type TEIResponse struct {
|
||||
Index int `json:"index"`
|
||||
Score float32 `json:"score"`
|
||||
}
|
||||
|
||||
func newTeiProvider(params []*commonpb.KeyValuePair, conf map[string]string) (modelProvider, error) {
|
||||
if !isEnable(conf, function.EnableTeiEnvStr) {
|
||||
return nil, fmt.Errorf("TEI rerank is disabled")
|
||||
}
|
||||
endpoint, maxBatch, err := parseParams(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
base, _ := url.Parse(endpoint)
|
||||
base.Path = "/rerank"
|
||||
model := baseModel{
|
||||
url: base.String(),
|
||||
maxBatch: maxBatch,
|
||||
queryKey: "query",
|
||||
docKey: "texts",
|
||||
parseScores: func(body []byte) ([]float32, error) {
|
||||
var results []TEIResponse
|
||||
if err := json.Unmarshal(body, &results); err != nil {
|
||||
return nil, fmt.Errorf("Rerank error, parsing TEI response failed: %v", err)
|
||||
}
|
||||
sort.Slice(results, func(i, j int) bool {
|
||||
return results[i].Index < results[j].Index
|
||||
})
|
||||
scores := make([]float32, len(results))
|
||||
for i, result := range results {
|
||||
scores[i] = result.Score
|
||||
}
|
||||
return scores, nil
|
||||
},
|
||||
}
|
||||
return &teiProvider{baseModel: model}, nil
|
||||
}
|
||||
|
||||
// isEnable reports whether a rerank provider is switched on.
//
// The "enable" entry of conf takes precedence (milvus.yaml > env): when it is
// present the provider is enabled only if its value is "true", compared
// case-insensitively. When it is absent, the environment variable envKey
// decides: the provider is enabled unless that variable is "false"
// (case-insensitive), so an unset variable means enabled.
func isEnable(conf map[string]string, envKey string) bool {
	if value, exists := conf["enable"]; exists {
		return strings.ToLower(value) == "true"
	}
	return strings.ToLower(os.Getenv(envKey)) != "false"
}
|
||||
|
||||
func parseParams(params []*commonpb.KeyValuePair) (string, int, error) {
|
||||
endpoint := ""
|
||||
maxBatch := 32
|
||||
for _, param := range params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
case function.EndpointParamKey:
|
||||
base, err := url.Parse(param.Value)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
if base.Scheme != "http" && base.Scheme != "https" {
|
||||
return "", 0, fmt.Errorf("Rerank endpoint: [%s] is not a valid http/https link", param.Value)
|
||||
}
|
||||
if base.Host == "" {
|
||||
return "", 0, fmt.Errorf("Rerank endpoint: [%s] is not a valid http/https link", param.Value)
|
||||
}
|
||||
endpoint = base.String()
|
||||
case maxBatchKeyName:
|
||||
if batch, err := strconv.ParseInt(param.Value, 10, 64); err != nil {
|
||||
return "", 0, fmt.Errorf("Rerank params error, maxBatch: %s is not a number", param.Value)
|
||||
} else {
|
||||
maxBatch = int(batch)
|
||||
}
|
||||
}
|
||||
}
|
||||
if endpoint == "" {
|
||||
return "", 0, fmt.Errorf("Rerank function lost params endpoint")
|
||||
}
|
||||
if maxBatch <= 0 {
|
||||
return "", 0, fmt.Errorf("Rerank function params max_batch must > 0, but got %d", maxBatch)
|
||||
}
|
||||
return endpoint, maxBatch, nil
|
||||
}
|
||||
|
||||
// genRerankRequestBody splits documents into consecutive batches of at most
// maxSize entries and encodes one JSON request body per batch.
//
// Each body is a JSON object with two fields: queryKey holding query and
// docKey holding the batch's documents. The bodies are returned in document
// order; an empty documents slice yields an empty, non-nil result. A maxSize
// that is not positive is rejected with an error, because it could never make
// progress through the documents.
func genRerankRequestBody(query string, documents []string, maxSize int, queryKey string, docKey string) ([][]byte, error) {
	if maxSize <= 0 {
		return nil, fmt.Errorf("Create model rerank request failed, batch size must > 0, but got %d", maxSize)
	}
	if len(documents) == 0 {
		return [][]byte{}, nil
	}

	// Clamp the step to the number of documents: the batches stay the same,
	// and the arithmetic below cannot overflow for a very large maxSize.
	step := maxSize
	if step > len(documents) {
		step = len(documents)
	}

	requestBodies := make([][]byte, 0, (len(documents)+step-1)/step)
	for start := 0; start < len(documents); start += step {
		end := start + step
		if end > len(documents) {
			end = len(documents)
		}
		jsonData, err := json.Marshal(map[string]any{
			queryKey: query,
			docKey:   documents[start:end],
		})
		if err != nil {
			return nil, fmt.Errorf("Create model rerank request failed, err: %s", err)
		}
		requestBodies = append(requestBodies, jsonData)
	}
	return requestBodies, nil
}
|
||||
|
||||
func newProvider(params []*commonpb.KeyValuePair) (modelProvider, error) {
|
||||
for _, param := range params {
|
||||
if strings.ToLower(param.Key) == providerParamName {
|
||||
provider := strings.ToLower(param.Value)
|
||||
conf := paramtable.Get().FunctionCfg.GetRerankModelProviders(provider)
|
||||
switch provider {
|
||||
case vllmProviderName:
|
||||
return newVllmProvider(params, conf)
|
||||
case teiProviderName:
|
||||
return newTeiProvider(params, conf)
|
||||
default:
|
||||
return nil, fmt.Errorf("Unknow rerank provider:%s", param.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("Lost rerank params:%s ", providerParamName)
|
||||
}
|
||||
|
||||
type ModelFunction[T PKType] struct {
|
||||
RerankBase
|
||||
|
||||
provider modelProvider
|
||||
queries []string
|
||||
}
|
||||
|
||||
func newModelFunction(collSchema *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema) (Reranker, error) {
|
||||
base, err := newRerankBase(collSchema, funcSchema, decayFunctionName, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(base.GetInputFieldNames()) != 1 {
|
||||
return nil, fmt.Errorf("Rerank model only supports single input, but gets [%s] input", base.GetInputFieldNames())
|
||||
}
|
||||
|
||||
if base.GetInputFieldTypes()[0] != schemapb.DataType_VarChar {
|
||||
return nil, fmt.Errorf("Rerank model only support varchar, bug got [%s]", base.GetInputFieldTypes()[0].String())
|
||||
}
|
||||
|
||||
provider, err := newProvider(funcSchema.Params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
queries := []string{}
|
||||
for _, param := range funcSchema.Params {
|
||||
if param.Key == queryKeyName {
|
||||
if err := json.Unmarshal([]byte(param.Value), &queries); err != nil {
|
||||
return nil, fmt.Errorf("Parse rerank params [queries] failed, err: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(queries) == 0 {
|
||||
return nil, fmt.Errorf("Rerank function lost params queries")
|
||||
}
|
||||
|
||||
if base.pkType == schemapb.DataType_Int64 {
|
||||
return &ModelFunction[int64]{RerankBase: *base, provider: provider, queries: queries}, nil
|
||||
} else {
|
||||
return &ModelFunction[string]{RerankBase: *base, provider: provider, queries: queries}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (model *ModelFunction[T]) processOneSearchData(ctx context.Context, searchParams *SearchParams, query string, cols []*columns) (*IDScores[T], error) {
|
||||
uniqueData := make(map[T]string)
|
||||
for _, col := range cols {
|
||||
texts := col.data[0].([]string)
|
||||
ids := col.ids.([]T)
|
||||
for idx, id := range ids {
|
||||
if _, ok := uniqueData[id]; !ok {
|
||||
uniqueData[id] = texts[idx]
|
||||
}
|
||||
}
|
||||
}
|
||||
ids := make([]T, 0, len(uniqueData))
|
||||
texts := make([]string, 0, len(uniqueData))
|
||||
for id, text := range uniqueData {
|
||||
ids = append(ids, id)
|
||||
texts = append(texts, text)
|
||||
}
|
||||
scores, err := model.provider.rerank(ctx, query, texts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rerankScores := map[T]float32{}
|
||||
for idx, id := range ids {
|
||||
rerankScores[id] = scores[idx]
|
||||
}
|
||||
return newIDScores(rerankScores, searchParams), nil
|
||||
}
|
||||
|
||||
func (model *ModelFunction[T]) Process(ctx context.Context, searchParams *SearchParams, inputs *rerankInputs) (*rerankOutputs, error) {
|
||||
if len(model.queries) != int(searchParams.nq) {
|
||||
return nil, fmt.Errorf("nq must equal to queries size, but got nq [%d], queries size [%d], queries: [%v]", searchParams.nq, len(model.queries), model.queries)
|
||||
}
|
||||
outputs := newRerankOutputs(searchParams)
|
||||
for idx, cols := range inputs.data {
|
||||
idScore, err := model.processOneSearchData(ctx, searchParams, model.queries[idx], cols)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
appendResult(outputs, idScore.ids, idScore.scores)
|
||||
}
|
||||
return outputs, nil
|
||||
}
|
||||
509
internal/util/function/rerank/model_function_test.go
Normal file
509
internal/util/function/rerank/model_function_test.go
Normal file
@ -0,0 +1,509 @@
|
||||
/*
|
||||
* # Licensed to the LF AI & Data foundation under one
|
||||
* # or more contributor license agreements. See the NOTICE file
|
||||
* # distributed with this work for additional information
|
||||
* # regarding copyright ownership. The ASF licenses this file
|
||||
* # to you under the Apache License, Version 2.0 (the
|
||||
* # "License"); you may not use this file except in compliance
|
||||
* # with the License. You may obtain a copy of the License at
|
||||
* #
|
||||
* # http://www.apache.org/licenses/LICENSE-2.0
|
||||
* #
|
||||
* # Unless required by applicable law or agreed to in writing, software
|
||||
* # distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* # See the License for the specific language governing permissions and
|
||||
* # limitations under the License.
|
||||
*/
|
||||
|
||||
package rerank
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
||||
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
||||
"github.com/milvus-io/milvus/internal/util/function"
|
||||
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
||||
)
|
||||
|
||||
func TestRerankModel(t *testing.T) {
|
||||
suite.Run(t, new(RerankModelSuite))
|
||||
}
|
||||
|
||||
type RerankModelSuite struct {
|
||||
suite.Suite
|
||||
}
|
||||
|
||||
func (s *RerankModelSuite) SetupTest() {
|
||||
paramtable.Init()
|
||||
paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = func() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RerankModelSuite) TestNewProvider() {
|
||||
{
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "unknown"},
|
||||
}
|
||||
_, err := newProvider(params)
|
||||
s.ErrorContains(err, "Unknow rerank provider")
|
||||
}
|
||||
|
||||
{
|
||||
_, err := newProvider([]*commonpb.KeyValuePair{})
|
||||
s.ErrorContains(err, "Lost rerank params")
|
||||
}
|
||||
|
||||
{
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "vllm"},
|
||||
{Key: function.EndpointParamKey, Value: "illegal"},
|
||||
}
|
||||
_, err := newProvider(params)
|
||||
s.ErrorContains(err, "is not a valid http/https link")
|
||||
}
|
||||
{
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "vllm"},
|
||||
{Key: function.EndpointParamKey, Value: "http://"},
|
||||
}
|
||||
_, err := newProvider(params)
|
||||
s.ErrorContains(err, "is not a valid http/https link")
|
||||
}
|
||||
{
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "vllm"},
|
||||
}
|
||||
_, err := newProvider(params)
|
||||
s.ErrorContains(err, "Rerank function lost params endpoint")
|
||||
}
|
||||
{
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "vllm"},
|
||||
{Key: function.EndpointParamKey, Value: "http://localhost:80"},
|
||||
{Key: maxBatchKeyName, Value: "-1"},
|
||||
}
|
||||
_, err := newProvider(params)
|
||||
s.ErrorContains(err, "Rerank function params max_batch must > 0")
|
||||
}
|
||||
{
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "vllm"},
|
||||
{Key: function.EndpointParamKey, Value: "http://localhost:80"},
|
||||
{Key: maxBatchKeyName, Value: "NotNum"},
|
||||
}
|
||||
_, err := newProvider(params)
|
||||
s.ErrorContains(err, "Rerank params error, maxBatch")
|
||||
}
|
||||
{
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "vllm"},
|
||||
{Key: function.EndpointParamKey, Value: "http://localhost:80"},
|
||||
}
|
||||
provder, err := newProvider(params)
|
||||
s.NoError(err)
|
||||
s.Equal(provder.getURL(), "http://localhost:80/v2/rerank")
|
||||
}
|
||||
{
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "tei"},
|
||||
{Key: function.EndpointParamKey, Value: "http://localhost:80"},
|
||||
}
|
||||
provder, err := newProvider(params)
|
||||
s.NoError(err)
|
||||
s.Equal(provder.getURL(), "http://localhost:80/rerank")
|
||||
}
|
||||
{
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "tei"},
|
||||
{Key: function.EndpointParamKey, Value: "http://localhost:80"},
|
||||
}
|
||||
paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = func() map[string]string {
|
||||
key := "tei.enable"
|
||||
return map[string]string{
|
||||
key: "false",
|
||||
}
|
||||
}
|
||||
_, err := newProvider(params)
|
||||
s.ErrorContains(err, "TEI rerank is disabled")
|
||||
paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = func() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
os.Setenv(function.EnableTeiEnvStr, "false")
|
||||
_, err = newProvider(params)
|
||||
s.ErrorContains(err, "TEI rerank is disabled")
|
||||
os.Unsetenv(function.EnableTeiEnvStr)
|
||||
}
|
||||
{
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "vllm"},
|
||||
{Key: function.EndpointParamKey, Value: "http://localhost:80"},
|
||||
}
|
||||
paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = func() map[string]string {
|
||||
key := "vllm.enable"
|
||||
return map[string]string{
|
||||
key: "false",
|
||||
}
|
||||
}
|
||||
_, err := newProvider(params)
|
||||
s.ErrorContains(err, "Vllm rerank is disabled")
|
||||
paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = func() map[string]string {
|
||||
return map[string]string{}
|
||||
}
|
||||
os.Setenv(function.EnableVllmEnvStr, "false")
|
||||
_, err = newProvider(params)
|
||||
s.ErrorContains(err, "Vllm rerank is disabled")
|
||||
os.Unsetenv(function.EnableVllmEnvStr)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RerankModelSuite) TestCallVllm() {
|
||||
{
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"results": [{"index": 0, "relevance_score": 0.1}, {"index": 1, "relevance_score": 0.2}]}`))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "vllm"},
|
||||
{Key: function.EndpointParamKey, Value: ts.URL},
|
||||
}
|
||||
provder, err := newProvider(params)
|
||||
s.NoError(err)
|
||||
_, err = provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"})
|
||||
s.ErrorContains(err, "Call Rerank service failed, 3 docs but got 2 scores")
|
||||
}
|
||||
{
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
w.Write([]byte(`{'object': 'error', 'message': 'The model vllm-test does not exist.', 'type': 'NotFoundError', 'param': None, 'code': 404}`))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "vllm"},
|
||||
{Key: function.EndpointParamKey, Value: ts.URL},
|
||||
}
|
||||
provder, err := newProvider(params)
|
||||
s.NoError(err)
|
||||
_, err = provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"})
|
||||
s.ErrorContains(err, "Call rerank model failed")
|
||||
}
|
||||
{
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`Not json data`))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "vllm"},
|
||||
{Key: function.EndpointParamKey, Value: ts.URL},
|
||||
}
|
||||
provder, err := newProvider(params)
|
||||
s.NoError(err)
|
||||
_, err = provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"})
|
||||
s.ErrorContains(err, "Rerank error, parsing vllm response failed")
|
||||
}
|
||||
{
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"results": [{"index": 0, "relevance_score": 0.0}, {"index": 2, "relevance_score": 0.2}, {"index": 1, "relevance_score": 0.1}]}`))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "vllm"},
|
||||
{Key: function.EndpointParamKey, Value: ts.URL},
|
||||
}
|
||||
provder, err := newProvider(params)
|
||||
s.NoError(err)
|
||||
scores, err := provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"})
|
||||
s.NoError(err)
|
||||
s.Equal([]float32{0.0, 0.1, 0.2}, scores)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RerankModelSuite) TestCallTEI() {
|
||||
{
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`[{"index":0,"score":0.0},{"index":1,"score":0.2}, {"index":2,"score":0.1}]`))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "tei"},
|
||||
{Key: function.EndpointParamKey, Value: ts.URL},
|
||||
}
|
||||
provder, err := newProvider(params)
|
||||
s.NoError(err)
|
||||
scores, err := provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"})
|
||||
s.NoError(err)
|
||||
s.Equal([]float32{0.0, 0.2, 0.1}, scores)
|
||||
}
|
||||
{
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`not json`))
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
params := []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "tei"},
|
||||
{Key: function.EndpointParamKey, Value: ts.URL},
|
||||
}
|
||||
provder, err := newProvider(params)
|
||||
s.NoError(err)
|
||||
_, err = provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"})
|
||||
s.ErrorContains(err, "Rerank error, parsing TEI response failed")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RerankModelSuite) TestNewModelFunction() {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
|
||||
},
|
||||
}
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{"text"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "tei"},
|
||||
{Key: function.EndpointParamKey, Value: "http://localhost:80"},
|
||||
{Key: queryKeyName, Value: `["q1"]`},
|
||||
},
|
||||
}
|
||||
{
|
||||
functionSchema.InputFieldNames = []string{"text", "ts"}
|
||||
_, err := newModelFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Rerank model only supports single input")
|
||||
}
|
||||
{
|
||||
functionSchema.InputFieldNames = []string{"ts"}
|
||||
_, err := newModelFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Rerank model only support varchar")
|
||||
functionSchema.InputFieldNames = []string{"text"}
|
||||
}
|
||||
{
|
||||
functionSchema.Params[2] = &commonpb.KeyValuePair{Key: queryKeyName, Value: `NotJson`}
|
||||
_, err := newModelFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Parse rerank params [queries] failed")
|
||||
}
|
||||
{
|
||||
functionSchema.Params[2] = &commonpb.KeyValuePair{Key: queryKeyName, Value: `[]`}
|
||||
_, err := newModelFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Rerank function lost params queries")
|
||||
}
|
||||
{
|
||||
functionSchema.Params[2] = &commonpb.KeyValuePair{Key: queryKeyName, Value: `["test"]`}
|
||||
_, err := newModelFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
}
|
||||
{
|
||||
schema.Fields[0] = &schemapb.FieldSchema{FieldID: 100, Name: "pk", DataType: schemapb.DataType_VarChar, IsPrimaryKey: true}
|
||||
_, err := newModelFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
}
|
||||
{
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: providerParamName, Value: `notExist`}
|
||||
_, err := newModelFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Unknow rerank provider")
|
||||
}
|
||||
{
|
||||
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: "NotExist", Value: `notExist`}
|
||||
_, err := newModelFunction(schema, functionSchema)
|
||||
s.ErrorContains(err, "Lost rerank params")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RerankModelSuite) TestRerankProcess() {
|
||||
schema := &schemapb.CollectionSchema{
|
||||
Name: "test",
|
||||
Fields: []*schemapb.FieldSchema{
|
||||
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
||||
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
||||
{
|
||||
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
||||
TypeParams: []*commonpb.KeyValuePair{
|
||||
{Key: "dim", Value: "4"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
{
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"results": [{"index": 0, "relevance_score": 0.1}, {"index": 1, "relevance_score": 0.2}]}`))
|
||||
}))
|
||||
defer ts.Close()
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{"text"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "tei"},
|
||||
{Key: function.EndpointParamKey, Value: ts.URL},
|
||||
{Key: queryKeyName, Value: `["q1"]`},
|
||||
},
|
||||
}
|
||||
|
||||
// empty
|
||||
{
|
||||
nq := int64(1)
|
||||
f, err := newModelFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs())
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs)
|
||||
s.NoError(err)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
s.Equal([]int64{}, ret.searchResultData.Topks)
|
||||
}
|
||||
|
||||
// no input field exist
|
||||
{
|
||||
nq := int64(1)
|
||||
f, err := newModelFunction(schema, functionSchema)
|
||||
data := genSearchResultData(nq, 10, schemapb.DataType_Int64, "noExist", 1000)
|
||||
s.NoError(err)
|
||||
|
||||
_, err = newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
|
||||
s.ErrorContains(err, "Search reaults mismatch rerank inputs")
|
||||
}
|
||||
}
|
||||
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
req := map[string]any{}
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
defer r.Body.Close()
|
||||
json.Unmarshal(body, &req)
|
||||
ret := map[string][]map[string]any{}
|
||||
ret["results"] = []map[string]any{}
|
||||
for i := range req["documents"].([]any) {
|
||||
d := map[string]any{}
|
||||
d["index"] = i
|
||||
d["relevance_score"] = float32(i) / 10
|
||||
ret["results"] = append(ret["results"], d)
|
||||
}
|
||||
jsonData, _ := json.Marshal(ret)
|
||||
w.Write(jsonData)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
// singleSearchResultData
|
||||
functionSchema := &schemapb.FunctionSchema{
|
||||
Name: "test",
|
||||
Type: schemapb.FunctionType_Rerank,
|
||||
InputFieldNames: []string{"text"},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: providerParamName, Value: "vllm"},
|
||||
{Key: function.EndpointParamKey, Value: ts.URL},
|
||||
{Key: queryKeyName, Value: `["q1", "q2"]`},
|
||||
},
|
||||
}
|
||||
|
||||
{
|
||||
nq := int64(1)
|
||||
{
|
||||
f, err := newModelFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
data := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
|
||||
_, err = f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs)
|
||||
s.ErrorContains(err, "nq must equal to queries size, but got nq [1], queries size [2]")
|
||||
}
|
||||
{
|
||||
functionSchema.Params[2].Value = `["q1"]`
|
||||
f, err := newModelFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
data := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 0, -1, -1, 1, false, []string{"COSINE"}}, inputs)
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3}, ret.searchResultData.Topks)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
}
|
||||
|
||||
{
|
||||
nq := int64(3)
|
||||
functionSchema.Params[2].Value = `["q1", "q2", "q3"]`
|
||||
f, err := newModelFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
data := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs())
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE", "COSINE"}}, inputs)
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
}
|
||||
}
|
||||
|
||||
// // multipSearchResultData
|
||||
// has empty inputs
|
||||
{
|
||||
nq := int64(1)
|
||||
functionSchema.Params[2].Value = `["q1"]`
|
||||
f, err := newModelFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
data1 := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
||||
// empty
|
||||
data2 := genSearchResultData(nq, 0, schemapb.DataType_VarChar, "text", 101)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs())
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE"}}, inputs)
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3}, ret.searchResultData.Topks)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
}
|
||||
// nq = 1
|
||||
{
|
||||
nq := int64(1)
|
||||
f, err := newModelFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
// ts/id data: 0 - 9
|
||||
data1 := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
||||
// ts/id data: 0 - 3
|
||||
data2 := genSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs())
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, -1, -1, 1, false, []string{"COSINE", "COSINE"}}, inputs)
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3}, ret.searchResultData.Topks)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
}
|
||||
// // nq = 3
|
||||
{
|
||||
nq := int64(3)
|
||||
functionSchema.Params[2].Value = `["q1", "q2", "q3"]`
|
||||
f, err := newModelFunction(schema, functionSchema)
|
||||
s.NoError(err)
|
||||
data1 := genSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
||||
data2 := genSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101)
|
||||
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs())
|
||||
ret, err := f.Process(context.Background(), &SearchParams{nq, 3, 2, 1, -1, 1, false, []string{"COSINE", "COSINE", "COSINE"}}, inputs)
|
||||
s.NoError(err)
|
||||
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
|
||||
s.Equal(int64(3), ret.searchResultData.TopK)
|
||||
}
|
||||
}
|
||||
@ -46,12 +46,18 @@ type RerankBase struct {
|
||||
pkType schemapb.DataType
|
||||
inputFieldIDs []int64
|
||||
inputFieldNames []string
|
||||
inputFieldTypes []schemapb.DataType
|
||||
|
||||
// TODO: The parameter is passed to the reranker, and the reranker decides whether to implement the parameter
|
||||
searchParams *searchParams
|
||||
}
|
||||
|
||||
func newRerankBase(coll *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema, rerankerName string, isSupportGroup bool, pkType schemapb.DataType) (*RerankBase, error) {
|
||||
func newRerankBase(coll *schemapb.CollectionSchema, funcSchema *schemapb.FunctionSchema, rerankerName string, isSupportGroup bool) (*RerankBase, error) {
|
||||
pkType, err := getPKType(coll)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
base := RerankBase{
|
||||
inputFieldNames: funcSchema.InputFieldNames,
|
||||
rerankerName: rerankerName,
|
||||
@ -82,6 +88,7 @@ func newRerankBase(coll *schemapb.CollectionSchema, funcSchema *schemapb.Functio
|
||||
return nil, fmt.Errorf("Function input field cannot be nullable: field %s", inputField.GetName())
|
||||
}
|
||||
base.inputFieldIDs = append(base.inputFieldIDs, inputField.FieldID)
|
||||
base.inputFieldTypes = append(base.inputFieldTypes, inputField.DataType)
|
||||
}
|
||||
return &base, nil
|
||||
}
|
||||
@ -90,6 +97,10 @@ func (base *RerankBase) GetInputFieldNames() []string {
|
||||
return base.inputFieldNames
|
||||
}
|
||||
|
||||
func (base *RerankBase) GetInputFieldTypes() []schemapb.DataType {
|
||||
return base.inputFieldTypes
|
||||
}
|
||||
|
||||
func (base *RerankBase) GetInputFieldIDs() []int64 {
|
||||
return base.inputFieldIDs
|
||||
}
|
||||
|
||||
@ -236,7 +236,7 @@ func getField(inputField *schemapb.FieldData, start int64, size int64) (any, err
|
||||
return inputField.GetScalars().GetBoolData().Data[start : start+size], nil
|
||||
}
|
||||
return []bool{}, nil
|
||||
case schemapb.DataType_String:
|
||||
case schemapb.DataType_String, schemapb.DataType_VarChar:
|
||||
if inputField.GetScalars() != nil && inputField.GetScalars().GetStringData() != nil {
|
||||
return inputField.GetScalars().GetStringData().Data[start : start+size], nil
|
||||
}
|
||||
@ -285,3 +285,17 @@ func maxMerge[T PKType](cols []*columns) map[T]float32 {
|
||||
}
|
||||
return srcScores
|
||||
}
|
||||
|
||||
func getPKType(collSchema *schemapb.CollectionSchema) (schemapb.DataType, error) {
|
||||
pkType := schemapb.DataType_None
|
||||
for _, field := range collSchema.Fields {
|
||||
if field.IsPrimaryKey {
|
||||
pkType = field.DataType
|
||||
}
|
||||
}
|
||||
|
||||
if pkType == schemapb.DataType_None {
|
||||
return pkType, fmt.Errorf("Collection %s can not found pk field", collSchema.Name)
|
||||
}
|
||||
return pkType, nil
|
||||
}
|
||||
|
||||
@ -47,7 +47,7 @@ type TeiEmbeddingProvider struct {
|
||||
}
|
||||
|
||||
func createTEIEmbeddingClient(apiKey string, endpoint string) (*tei.TEIEmbedding, error) {
|
||||
enable := os.Getenv(enableTeiEnvStr)
|
||||
enable := os.Getenv(EnableTeiEnvStr)
|
||||
if strings.ToLower(enable) == "false" {
|
||||
return nil, errors.New("TEI model serving is not enabled")
|
||||
}
|
||||
@ -69,7 +69,7 @@ func NewTEIEmbeddingProvider(fieldSchema *schemapb.FieldSchema, functionSchema *
|
||||
|
||||
for _, param := range functionSchema.Params {
|
||||
switch strings.ToLower(param.Key) {
|
||||
case endpointParamKey:
|
||||
case EndpointParamKey:
|
||||
endpoint = param.Value
|
||||
case ingestionPromptParamKey:
|
||||
ingestionPrompt = param.Value
|
||||
|
||||
@ -70,7 +70,7 @@ func createTEIProvider(url string, schema *schemapb.FieldSchema, providerName st
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: credentialParamKey, Value: "mock"},
|
||||
{Key: endpointParamKey, Value: url},
|
||||
{Key: EndpointParamKey, Value: url},
|
||||
{Key: ingestionPromptParamKey, Value: "doc:"},
|
||||
{Key: searchPromptParamKey, Value: "query:"},
|
||||
},
|
||||
@ -154,8 +154,8 @@ func (s *TEITextEmbeddingProviderSuite) TestCreateTEIEmbeddingClient() {
|
||||
_, err = createTEIEmbeddingClient("", "http://mymock.com")
|
||||
s.NoError(err)
|
||||
|
||||
os.Setenv(enableTeiEnvStr, "false")
|
||||
defer os.Unsetenv(enableTeiEnvStr)
|
||||
os.Setenv(EnableTeiEnvStr, "false")
|
||||
defer os.Unsetenv(EnableTeiEnvStr)
|
||||
_, err = createTEIEmbeddingClient("", "http://mymock.com")
|
||||
s.Error(err)
|
||||
}
|
||||
@ -170,7 +170,7 @@ func (s *TEITextEmbeddingProviderSuite) TestNewTEIEmbeddingProvider() {
|
||||
OutputFieldIds: []int64{102},
|
||||
Params: []*commonpb.KeyValuePair{
|
||||
{Key: credentialParamKey, Value: "mock"},
|
||||
{Key: endpointParamKey, Value: "http://mymock.com"},
|
||||
{Key: EndpointParamKey, Value: "http://mymock.com"},
|
||||
},
|
||||
}
|
||||
provider, err := NewTEIEmbeddingProvider(s.schema.Fields[2], functionSchema, map[string]string{}, credentials.NewCredentials(map[string]string{"mock.apikey": "mock"}))
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
|
||||
type functionConfig struct {
|
||||
TextEmbeddingProviders ParamGroup `refreshable:"true"`
|
||||
RerankModelProviders ParamGroup `refreshable:"true"`
|
||||
}
|
||||
|
||||
func (p *functionConfig) init(base *BaseTable) {
|
||||
@ -73,6 +74,23 @@ func (p *functionConfig) init(base *BaseTable) {
|
||||
},
|
||||
}
|
||||
p.TextEmbeddingProviders.Init(base.mgr)
|
||||
|
||||
p.RerankModelProviders = ParamGroup{
|
||||
KeyPrefix: "function.rerank.model.providers.",
|
||||
Version: "2.6.0",
|
||||
Export: true,
|
||||
DocFunc: func(key string) string {
|
||||
switch key {
|
||||
case "tei.enable":
|
||||
return "Whether to enable TEI rerank service"
|
||||
case "vllm.enable":
|
||||
return "Whether to enable vllm rerank service"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
},
|
||||
}
|
||||
p.RerankModelProviders.Init(base.mgr)
|
||||
}
|
||||
|
||||
const (
|
||||
@ -92,3 +110,17 @@ func (p *functionConfig) GetTextEmbeddingProviderConfig(providerName string) map
|
||||
}
|
||||
return matchedParam
|
||||
}
|
||||
|
||||
func (p *functionConfig) GetRerankModelProviders(providerName string) map[string]string {
|
||||
matchedParam := make(map[string]string)
|
||||
|
||||
params := p.RerankModelProviders.GetValue()
|
||||
prefix := providerName + "."
|
||||
|
||||
for k, v := range params {
|
||||
if strings.HasPrefix(k, prefix) {
|
||||
matchedParam[strings.TrimPrefix(k, prefix)] = v
|
||||
}
|
||||
}
|
||||
return matchedParam
|
||||
}
|
||||
|
||||
@ -925,7 +925,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
|
||||
)
|
||||
vectors_to_search = rng.random((1, dim))
|
||||
error = {ct.err_code: 65535,
|
||||
ct.err_msg: f"Decay function only supoorts single input, but gets [[reranker_field id]] input"}
|
||||
ct.err_msg: f"Decay function only supports single input, but gets [[reranker_field id]] input"}
|
||||
self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
|
||||
check_task=CheckTasks.err_res, check_items=error)
|
||||
|
||||
@ -1053,7 +1053,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
|
||||
)
|
||||
vectors_to_search = rng.random((1, dim))
|
||||
error = {ct.err_code: 65535,
|
||||
ct.err_msg: f"Unsupported rerank function: [1] , list of supported [decay]"}
|
||||
ct.err_msg: f"Unsupported rerank function: [1]"}
|
||||
self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
|
||||
check_task=CheckTasks.err_res, check_items=error)
|
||||
|
||||
@ -1098,7 +1098,7 @@ class TestMilvusClientSearchInvalid(TestMilvusClientV2Base):
|
||||
)
|
||||
vectors_to_search = rng.random((1, dim))
|
||||
error = {ct.err_code: 65535,
|
||||
ct.err_msg: f"Unsupported rerank function: [{not_supported_reranker}] , list of supported [decay]"}
|
||||
ct.err_msg: f"Unsupported rerank function: [{not_supported_reranker}]"}
|
||||
self.search(client, collection_name, vectors_to_search, ranker=my_rerank_fn,
|
||||
check_task=CheckTasks.err_res, check_items=error)
|
||||
|
||||
@ -3997,4 +3997,4 @@ class TestMilvusClientSearchRerankValid(TestMilvusClientV2Base):
|
||||
"nq": len(vectors_to_search),
|
||||
"pk_name": default_primary_key_field_name,
|
||||
"limit": default_limit}
|
||||
)
|
||||
)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user