mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
Cherry-pick from master pr: https://github.com/milvus-io/milvus/pull/44364 issue: https://github.com/milvus-io/milvus/issues/44363 --------- Signed-off-by: PjJinchen <6268414+pj1987111@users.noreply.github.com>
683 lines
23 KiB
Go
683 lines
23 KiB
Go
/*
|
|
* # Licensed to the LF AI & Data foundation under one
|
|
* # or more contributor license agreements. See the NOTICE file
|
|
* # distributed with this work for additional information
|
|
* # regarding copyright ownership. The ASF licenses this file
|
|
* # to you under the Apache License, Version 2.0 (the
|
|
* # "License"); you may not use this file except in compliance
|
|
* # with the License. You may obtain a copy of the License at
|
|
* #
|
|
* # http://www.apache.org/licenses/LICENSE-2.0
|
|
* #
|
|
* # Unless required by applicable law or agreed to in writing, software
|
|
* # distributed under the License is distributed on an "AS IS" BASIS,
|
|
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* # See the License for the specific language governing permissions and
|
|
* # limitations under the License.
|
|
*/
|
|
|
|
package rerank
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/suite"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/internal/util/credentials"
|
|
"github.com/milvus-io/milvus/internal/util/function/embedding"
|
|
"github.com/milvus-io/milvus/internal/util/function/models"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
|
)
|
|
|
|
func TestRerankModel(t *testing.T) {
|
|
suite.Run(t, new(RerankModelSuite))
|
|
}
|
|
|
|
type RerankModelSuite struct {
|
|
suite.Suite
|
|
}
|
|
|
|
func (s *RerankModelSuite) SetupTest() {
|
|
paramtable.Init()
|
|
paramtable.Get().CredentialCfg.Credential.GetFunc = func() map[string]string {
|
|
return map[string]string{
|
|
"mock.apikey": "mock",
|
|
}
|
|
}
|
|
paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = func() map[string]string {
|
|
return map[string]string{}
|
|
}
|
|
}
|
|
|
|
func (s *RerankModelSuite) TestNewProvider() {
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "unknown"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.ErrorContains(err, "Unknow rerank model provider")
|
|
}
|
|
|
|
{
|
|
_, err := newProvider([]*commonpb.KeyValuePair{})
|
|
s.ErrorContains(err, "Lost rerank params")
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "vllm"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.ErrorContains(err, "Rerank function lost params endpoint")
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "vllm"},
|
|
{Key: models.EndpointParamKey, Value: "http://localhost:80"},
|
|
{Key: models.MaxClientBatchSizeParamKey, Value: "-1"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.ErrorContains(err, "Rerank function params max_batch must > 0")
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "vllm"},
|
|
{Key: models.EndpointParamKey, Value: "http://localhost:80"},
|
|
{Key: models.MaxClientBatchSizeParamKey, Value: "NotNum"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.ErrorContains(err, "is not a valid numbe")
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "vllm"},
|
|
{Key: models.EndpointParamKey, Value: "http://localhost:80"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.NoError(err)
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "vllm"},
|
|
{Key: models.EndpointParamKey, Value: "http://localhost:80"},
|
|
{Key: models.VllmTruncateParamName, Value: "unknow"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.Error(err)
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "vllm"},
|
|
{Key: models.EndpointParamKey, Value: "http://localhost:80"},
|
|
}
|
|
paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = func() map[string]string {
|
|
key := "vllm.enable"
|
|
return map[string]string{
|
|
key: "false",
|
|
}
|
|
}
|
|
_, err := newProvider(params)
|
|
s.ErrorContains(err, "Rerank provider: [vllm] is disabled")
|
|
paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = func() map[string]string {
|
|
return map[string]string{}
|
|
}
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "tei"},
|
|
{Key: models.EndpointParamKey, Value: "http://localhost:80"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.NoError(err)
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "tei"},
|
|
{Key: models.EndpointParamKey, Value: "http://localhost:80"},
|
|
{Key: models.TruncateParamKey, Value: "unknow"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.Error(err)
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "tei"},
|
|
{Key: models.EndpointParamKey, Value: "http://localhost:80"},
|
|
{Key: models.TruncateParamKey, Value: "true"},
|
|
{Key: models.MaxClientBatchSizeParamKey, Value: "10"},
|
|
{Key: models.TruncationDirectionParamKey, Value: "unknow"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.Error(err)
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "tei"},
|
|
{Key: models.EndpointParamKey, Value: "http://localhost:80"},
|
|
{Key: models.TruncateParamKey, Value: "true"},
|
|
{Key: models.TruncationDirectionParamKey, Value: "Left"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.NoError(err)
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "tei"},
|
|
{Key: models.EndpointParamKey, Value: "http://localhost:80"},
|
|
}
|
|
paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = func() map[string]string {
|
|
key := "tei.enable"
|
|
return map[string]string{
|
|
key: "false",
|
|
}
|
|
}
|
|
_, err := newProvider(params)
|
|
s.ErrorContains(err, "Rerank provider: [tei] is disabled")
|
|
paramtable.Get().FunctionCfg.RerankModelProviders.GetFunc = func() map[string]string {
|
|
return map[string]string{}
|
|
}
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "siliconflow"},
|
|
{Key: models.CredentialParamKey, Value: "mock"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.ErrorContains(err, "siliconflow rerank model name is required")
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "siliconflow"},
|
|
{Key: models.ModelNameParamKey, Value: "siliconflow-test"},
|
|
{Key: models.CredentialParamKey, Value: "mock"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.NoError(err)
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "siliconflow"},
|
|
{Key: models.ModelNameParamKey, Value: "siliconflow-test"},
|
|
{Key: models.CredentialParamKey, Value: "mock"},
|
|
{Key: models.MaxClientBatchSizeParamKey, Value: "10"},
|
|
{Key: models.MaxChunksPerDocParamKey, Value: "10"},
|
|
{Key: models.OverlapTokensParamKey, Value: "10"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.NoError(err)
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "cohere"},
|
|
{Key: models.CredentialParamKey, Value: "mock"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.ErrorContains(err, "cohere rerank model name is required")
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "cohere"},
|
|
{Key: models.ModelNameParamKey, Value: "cohere-test"},
|
|
{Key: models.CredentialParamKey, Value: "mock"},
|
|
{Key: models.MaxClientBatchSizeParamKey, Value: "10"},
|
|
{Key: models.MaxTKsPerDocParamKey, Value: "10"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.NoError(err)
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "voyageai"},
|
|
{Key: models.CredentialParamKey, Value: "mock"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.ErrorContains(err, "voyageai rerank model name is required")
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "voyageai"},
|
|
{Key: models.ModelNameParamKey, Value: "voyageai-test"},
|
|
{Key: models.CredentialParamKey, Value: "mock"},
|
|
{Key: models.MaxClientBatchSizeParamKey, Value: "10"},
|
|
{Key: models.TruncateParamKey, Value: "true"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.NoError(err)
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "ali"},
|
|
{Key: models.CredentialParamKey, Value: "mock"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.ErrorContains(err, "ali rerank model name is required")
|
|
}
|
|
{
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "ali"},
|
|
{Key: models.ModelNameParamKey, Value: "ali-test"},
|
|
{Key: models.CredentialParamKey, Value: "mock"},
|
|
{Key: models.MaxClientBatchSizeParamKey, Value: "10"},
|
|
}
|
|
_, err := newProvider(params)
|
|
s.NoError(err)
|
|
}
|
|
}
|
|
|
|
func (s *RerankModelSuite) TestCallVllm() {
|
|
{
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
w.Write([]byte(`{'object': 'error', 'message': 'The model vllm-test does not exist.', 'type': 'NotFoundError', 'param': None, 'code': 404}`))
|
|
}))
|
|
defer ts.Close()
|
|
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "vllm"},
|
|
{Key: models.EndpointParamKey, Value: ts.URL},
|
|
}
|
|
provder, err := newProvider(params)
|
|
s.NoError(err)
|
|
_, err = provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"})
|
|
s.ErrorContains(err, "Call service failed")
|
|
}
|
|
{
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`Not json data`))
|
|
}))
|
|
defer ts.Close()
|
|
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "vllm"},
|
|
{Key: models.EndpointParamKey, Value: ts.URL},
|
|
}
|
|
provder, err := newProvider(params)
|
|
s.NoError(err)
|
|
_, err = provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"})
|
|
s.ErrorContains(err, "Call service failed")
|
|
}
|
|
{
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`{"results": [{"index": 0, "relevance_score": 0.0}, {"index": 2, "relevance_score": 0.2}, {"index": 1, "relevance_score": 0.1}]}`))
|
|
}))
|
|
defer ts.Close()
|
|
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "vllm"},
|
|
{Key: models.EndpointParamKey, Value: ts.URL},
|
|
}
|
|
provder, err := newProvider(params)
|
|
s.NoError(err)
|
|
scores, err := provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"})
|
|
s.NoError(err)
|
|
s.Equal([]float32{0.0, 0.1, 0.2}, scores)
|
|
}
|
|
}
|
|
|
|
func (s *RerankModelSuite) TestCallTEI() {
|
|
{
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`[{"index":0,"score":0.0},{"index":1,"score":0.2}, {"index":2,"score":0.1}]`))
|
|
}))
|
|
defer ts.Close()
|
|
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "tei"},
|
|
{Key: models.EndpointParamKey, Value: ts.URL},
|
|
}
|
|
provder, err := newProvider(params)
|
|
s.NoError(err)
|
|
scores, err := provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"})
|
|
s.NoError(err)
|
|
s.Equal([]float32{0.0, 0.2, 0.1}, scores)
|
|
}
|
|
{
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`not json`))
|
|
}))
|
|
defer ts.Close()
|
|
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "tei"},
|
|
{Key: models.EndpointParamKey, Value: ts.URL},
|
|
}
|
|
provder, err := newProvider(params)
|
|
s.NoError(err)
|
|
_, err = provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"})
|
|
s.ErrorContains(err, "Call service failed")
|
|
}
|
|
}
|
|
|
|
func (s *RerankModelSuite) TestCallSiliconFlow() {
|
|
{
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`{"results": [{"index": 0, "relevance_score": 0.1}, {"index": 1, "relevance_score": 0.2}]}`))
|
|
}))
|
|
defer ts.Close()
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "siliconflow"},
|
|
{Key: models.ModelNameParamKey, Value: "siliconflow-test"},
|
|
{Key: models.CredentialParamKey, Value: "mock"},
|
|
}
|
|
provder, err := newSiliconflowProvider(params, map[string]string{models.URLParamKey: ts.URL}, credentials.NewCredentials(map[string]string{"mock.apikey": "mock"}))
|
|
s.NoError(err)
|
|
scores, err := provder.rerank(context.Background(), "mytest", []string{"t1", "t2"})
|
|
s.NoError(err)
|
|
s.Equal([]float32{0.1, 0.2}, scores)
|
|
}
|
|
}
|
|
|
|
func (s *RerankModelSuite) TestCallCohere() {
|
|
{
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`{"results": [{"index": 0, "relevance_score": 0.1}, {"index": 1, "relevance_score": 0.2}]}`))
|
|
}))
|
|
defer ts.Close()
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "cohere"},
|
|
{Key: models.ModelNameParamKey, Value: "cohere-test"},
|
|
{Key: models.CredentialParamKey, Value: "mock"},
|
|
}
|
|
provder, err := newCohereProvider(params, map[string]string{models.URLParamKey: ts.URL}, credentials.NewCredentials(map[string]string{"mock.apikey": "mock"}))
|
|
s.NoError(err)
|
|
scores, err := provder.rerank(context.Background(), "mytest", []string{"t1", "t2"})
|
|
s.NoError(err)
|
|
s.Equal([]float32{0.1, 0.2}, scores)
|
|
}
|
|
}
|
|
|
|
func (s *RerankModelSuite) TestCallVoyageAI() {
|
|
{
|
|
repStr := `{
|
|
"object": "list",
|
|
"data": [{"object": "rerank", "index": 0, "relevance_score": 0.0}, {"object": "rerank", "index": 2, "relevance_score": 0.2}, {"object": "rerank", "index": 1, "relevance_score": 0.1}]
|
|
}`
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(repStr))
|
|
}))
|
|
defer ts.Close()
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "voyageai"},
|
|
{Key: models.ModelNameParamKey, Value: "voyageai-test"},
|
|
{Key: models.CredentialParamKey, Value: "mock"},
|
|
}
|
|
provder, err := newVoyageaiProvider(params, map[string]string{models.URLParamKey: ts.URL}, credentials.NewCredentials(map[string]string{"mock.apikey": "mock"}))
|
|
s.NoError(err)
|
|
scores, err := provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"})
|
|
s.NoError(err)
|
|
s.Equal([]float32{0.0, 0.1, 0.2}, scores)
|
|
}
|
|
}
|
|
|
|
func (s *RerankModelSuite) TestCallAli() {
|
|
{
|
|
repStr := `{"output":{"results":[{"index":0,"relevance_score":0},{"index":1,"relevance_score":0.1},{"index":2,"relevance_score":0.2}]},"usage":{"total_tokens":1},"request_id":"x"}`
|
|
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(repStr))
|
|
}))
|
|
defer ts.Close()
|
|
params := []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "ali"},
|
|
{Key: models.ModelNameParamKey, Value: "ali-test"},
|
|
{Key: models.CredentialParamKey, Value: "mock"},
|
|
}
|
|
provder, err := newAliProvider(params, map[string]string{models.URLParamKey: ts.URL}, credentials.NewCredentials(map[string]string{"mock.apikey": "mock"}))
|
|
s.NoError(err)
|
|
scores, err := provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"})
|
|
s.NoError(err)
|
|
s.Equal([]float32{0.0, 0.1, 0.2}, scores)
|
|
}
|
|
}
|
|
|
|
func (s *RerankModelSuite) TestNewModelFunction() {
|
|
schema := &schemapb.CollectionSchema{
|
|
Name: "test",
|
|
Fields: []*schemapb.FieldSchema{
|
|
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
|
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
|
{
|
|
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
|
TypeParams: []*commonpb.KeyValuePair{
|
|
{Key: "dim", Value: "4"},
|
|
},
|
|
},
|
|
{FieldID: 102, Name: "ts", DataType: schemapb.DataType_Int64},
|
|
},
|
|
}
|
|
functionSchema := &schemapb.FunctionSchema{
|
|
Name: "test",
|
|
Type: schemapb.FunctionType_Rerank,
|
|
InputFieldNames: []string{"text"},
|
|
Params: []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "tei"},
|
|
{Key: models.EndpointParamKey, Value: "http://localhost:80"},
|
|
{Key: queryKeyName, Value: `["q1"]`},
|
|
},
|
|
}
|
|
{
|
|
functionSchema.InputFieldNames = []string{"text", "ts"}
|
|
_, err := newModelFunction(schema, functionSchema)
|
|
s.ErrorContains(err, "Rerank model only supports single input")
|
|
}
|
|
{
|
|
functionSchema.InputFieldNames = []string{"ts"}
|
|
_, err := newModelFunction(schema, functionSchema)
|
|
s.ErrorContains(err, "Rerank model only support varchar")
|
|
functionSchema.InputFieldNames = []string{"text"}
|
|
}
|
|
{
|
|
functionSchema.Params[2] = &commonpb.KeyValuePair{Key: queryKeyName, Value: `NotJson`}
|
|
_, err := newModelFunction(schema, functionSchema)
|
|
s.ErrorContains(err, "Parse rerank params [queries] failed")
|
|
}
|
|
{
|
|
functionSchema.Params[2] = &commonpb.KeyValuePair{Key: queryKeyName, Value: `[]`}
|
|
_, err := newModelFunction(schema, functionSchema)
|
|
s.ErrorContains(err, "Rerank function lost params queries")
|
|
}
|
|
{
|
|
functionSchema.Params[2] = &commonpb.KeyValuePair{Key: queryKeyName, Value: `["test"]`}
|
|
_, err := newModelFunction(schema, functionSchema)
|
|
s.NoError(err)
|
|
}
|
|
{
|
|
schema.Fields[0] = &schemapb.FieldSchema{FieldID: 100, Name: "pk", DataType: schemapb.DataType_VarChar, IsPrimaryKey: true}
|
|
_, err := newModelFunction(schema, functionSchema)
|
|
s.NoError(err)
|
|
}
|
|
{
|
|
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: providerParamName, Value: `notExist`}
|
|
_, err := newModelFunction(schema, functionSchema)
|
|
s.ErrorContains(err, "Unknow rerank model provider")
|
|
}
|
|
{
|
|
functionSchema.Params[0] = &commonpb.KeyValuePair{Key: "NotExist", Value: `notExist`}
|
|
_, err := newModelFunction(schema, functionSchema)
|
|
s.ErrorContains(err, "Lost rerank params")
|
|
}
|
|
}
|
|
|
|
func (s *RerankModelSuite) TestRerankProcess() {
|
|
schema := &schemapb.CollectionSchema{
|
|
Name: "test",
|
|
Fields: []*schemapb.FieldSchema{
|
|
{FieldID: 100, Name: "pk", DataType: schemapb.DataType_Int64, IsPrimaryKey: true},
|
|
{FieldID: 101, Name: "text", DataType: schemapb.DataType_VarChar},
|
|
{
|
|
FieldID: 102, Name: "vector", DataType: schemapb.DataType_FloatVector,
|
|
TypeParams: []*commonpb.KeyValuePair{
|
|
{Key: "dim", Value: "4"},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
{
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte(`{"results": [{"index": 0, "relevance_score": 0.1}, {"index": 1, "relevance_score": 0.2}]}`))
|
|
}))
|
|
defer ts.Close()
|
|
functionSchema := &schemapb.FunctionSchema{
|
|
Name: "test",
|
|
Type: schemapb.FunctionType_Rerank,
|
|
InputFieldNames: []string{"text"},
|
|
Params: []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "tei"},
|
|
{Key: models.EndpointParamKey, Value: ts.URL},
|
|
{Key: queryKeyName, Value: `["q1"]`},
|
|
},
|
|
}
|
|
|
|
// empty
|
|
{
|
|
nq := int64(1)
|
|
f, err := newModelFunction(schema, functionSchema)
|
|
s.NoError(err)
|
|
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{}, f.GetInputFieldIDs(), false)
|
|
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
|
|
s.NoError(err)
|
|
s.Equal(int64(3), ret.searchResultData.TopK)
|
|
s.Equal([]int64{}, ret.searchResultData.Topks)
|
|
}
|
|
|
|
// no input field exist
|
|
{
|
|
nq := int64(1)
|
|
f, err := newModelFunction(schema, functionSchema)
|
|
data := embedding.GenSearchResultData(nq, 10, schemapb.DataType_Int64, "noExist", 1000)
|
|
s.NoError(err)
|
|
|
|
_, err = newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
|
|
s.ErrorContains(err, "Search reaults mismatch rerank inputs")
|
|
}
|
|
}
|
|
|
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
req := map[string]any{}
|
|
body, _ := io.ReadAll(r.Body)
|
|
defer r.Body.Close()
|
|
json.Unmarshal(body, &req)
|
|
ret := map[string][]map[string]any{}
|
|
ret["results"] = []map[string]any{}
|
|
for i := range req["documents"].([]any) {
|
|
d := map[string]any{}
|
|
d["index"] = i
|
|
d["relevance_score"] = float32(i) / 10
|
|
ret["results"] = append(ret["results"], d)
|
|
}
|
|
jsonData, _ := json.Marshal(ret)
|
|
w.Write(jsonData)
|
|
}))
|
|
defer ts.Close()
|
|
|
|
// singleSearchResultData
|
|
functionSchema := &schemapb.FunctionSchema{
|
|
Name: "test",
|
|
Type: schemapb.FunctionType_Rerank,
|
|
InputFieldNames: []string{"text"},
|
|
Params: []*commonpb.KeyValuePair{
|
|
{Key: providerParamName, Value: "vllm"},
|
|
{Key: models.EndpointParamKey, Value: ts.URL},
|
|
{Key: queryKeyName, Value: `["q1", "q2"]`},
|
|
},
|
|
}
|
|
|
|
{
|
|
nq := int64(1)
|
|
{
|
|
f, err := newModelFunction(schema, functionSchema)
|
|
s.NoError(err)
|
|
data := embedding.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
|
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
|
|
_, err = f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
|
|
s.ErrorContains(err, "nq must equal to queries size, but got nq [1], queries size [2]")
|
|
}
|
|
{
|
|
functionSchema.Params[2].Value = `["q1"]`
|
|
f, err := newModelFunction(schema, functionSchema)
|
|
s.NoError(err)
|
|
data := embedding.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
|
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
|
|
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 0, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
|
|
s.NoError(err)
|
|
s.Equal([]int64{3}, ret.searchResultData.Topks)
|
|
s.Equal(int64(3), ret.searchResultData.TopK)
|
|
}
|
|
|
|
{
|
|
nq := int64(3)
|
|
functionSchema.Params[2].Value = `["q1", "q2", "q3"]`
|
|
f, err := newModelFunction(schema, functionSchema)
|
|
s.NoError(err)
|
|
data := embedding.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
|
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data}, f.GetInputFieldIDs(), false)
|
|
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE", "COSINE"}), inputs)
|
|
s.NoError(err)
|
|
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
|
|
s.Equal(int64(3), ret.searchResultData.TopK)
|
|
}
|
|
}
|
|
|
|
// // multipSearchResultData
|
|
// has empty inputs
|
|
{
|
|
nq := int64(1)
|
|
functionSchema.Params[2].Value = `["q1"]`
|
|
f, err := newModelFunction(schema, functionSchema)
|
|
s.NoError(err)
|
|
data1 := embedding.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
|
// empty
|
|
data2 := embedding.GenSearchResultData(nq, 0, schemapb.DataType_VarChar, "text", 101)
|
|
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
|
|
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE"}), inputs)
|
|
s.NoError(err)
|
|
s.Equal([]int64{3}, ret.searchResultData.Topks)
|
|
s.Equal(int64(3), ret.searchResultData.TopK)
|
|
}
|
|
// nq = 1
|
|
{
|
|
nq := int64(1)
|
|
f, err := newModelFunction(schema, functionSchema)
|
|
s.NoError(err)
|
|
// ts/id data: 0 - 9
|
|
data1 := embedding.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
|
// ts/id data: 0 - 3
|
|
data2 := embedding.GenSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101)
|
|
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
|
|
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, -1, -1, 1, false, "", []string{"COSINE", "COSINE"}), inputs)
|
|
s.NoError(err)
|
|
s.Equal([]int64{3}, ret.searchResultData.Topks)
|
|
s.Equal(int64(3), ret.searchResultData.TopK)
|
|
}
|
|
// // nq = 3
|
|
{
|
|
nq := int64(3)
|
|
functionSchema.Params[2].Value = `["q1", "q2", "q3"]`
|
|
f, err := newModelFunction(schema, functionSchema)
|
|
s.NoError(err)
|
|
data1 := embedding.GenSearchResultData(nq, 10, schemapb.DataType_VarChar, "text", 101)
|
|
data2 := embedding.GenSearchResultData(nq, 4, schemapb.DataType_VarChar, "text", 101)
|
|
inputs, _ := newRerankInputs([]*schemapb.SearchResultData{data1, data2}, f.GetInputFieldIDs(), false)
|
|
ret, err := f.Process(context.Background(), NewSearchParams(nq, 3, 2, 1, -1, 1, false, "", []string{"COSINE", "COSINE", "COSINE"}), inputs)
|
|
s.NoError(err)
|
|
s.Equal([]int64{3, 3, 3}, ret.searchResultData.Topks)
|
|
s.Equal(int64(3), ret.searchResultData.TopK)
|
|
}
|
|
}
|