diff --git a/internal/util/function/models/ali/ali_dashscope_client.go b/internal/util/function/models/ali/ali_dashscope_client.go index 5421f319d9..233682c3ef 100644 --- a/internal/util/function/models/ali/ali_dashscope_client.go +++ b/internal/util/function/models/ali/ali_dashscope_client.go @@ -122,3 +122,59 @@ func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim }) return res, nil } + +type Inputs struct { + Query string `json:"query"` + Documents []string `json:"documents"` +} +type RerankRequest struct { + // ID of the model to use. + Model string `json:"model"` + + // Input text to embed, encoded as a string. + Inputs Inputs `json:"input"` +} + +type RerankItem struct { + Index int `json:"index"` + RelevanceScore float32 `json:"relevance_score"` +} + +type RerankOutput struct { + Results []RerankItem `json:"results"` +} + +type RerankResponse struct { + Output RerankOutput `json:"output"` + Usage Usage `json:"usage"` + RequestID string `json:"request_id"` +} + +type AliDashScopeRerank struct { + apiKey string +} + +func NewAliDashScopeRerank(apiKey string) *AliDashScopeRerank { + return &AliDashScopeRerank{ + apiKey: apiKey, + } +} + +func (c *AliDashScopeRerank) Rerank(url string, modelName string, query string, texts []string, params map[string]any, timeoutSec int64) (*RerankResponse, error) { + var r RerankRequest + r.Model = modelName + r.Inputs = Inputs{query, texts} + + headers := map[string]string{ + "Content-Type": "application/json", + "Authorization": fmt.Sprintf("Bearer %s", c.apiKey), + } + res, err := models.PostRequest[RerankResponse](r, url, headers, timeoutSec) + if err != nil { + return nil, err + } + sort.Slice(res.Output.Results, func(i, j int) bool { + return res.Output.Results[i].Index < res.Output.Results[j].Index + }) + return res, nil +} diff --git a/internal/util/function/models/ali/ali_dashscope_client_test.go b/internal/util/function/models/ali/ali_dashscope_client_test.go index 6fb7cd04e5..f3b4309ab0 100644 --- a/internal/util/function/models/ali/ali_dashscope_client_test.go +++ b/internal/util/function/models/ali/ali_dashscope_client_test.go @@ -114,3 +114,44 @@ func TestEmbeddingFailed(t *testing.T) { assert.True(t, err != nil) } } + +func TestRerankOK(t *testing.T) { + repStr := `{"output":{"results":[{"index":0,"relevance_score":0},{"index":1,"relevance_score":0.1},{"index":2,"relevance_score":0.2}]},"usage":{"total_tokens":1},"request_id":"x"}` + var res RerankResponse + err := json.Unmarshal([]byte(repStr), &res) + assert.NoError(t, err) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + data, _ := json.Marshal(res) + w.Write(data) + })) + defer ts.Close() + url := ts.URL + + { + c := NewAliDashScopeRerank("mock_key") + r, err := c.Rerank(url, "gte-rerank-v2", "query", []string{"t1", "t2", "t3"}, nil, 0) + assert.True(t, err == nil) + assert.Equal(t, r.Output.Results[0].Index, 0) + assert.Equal(t, r.Output.Results[1].Index, 1) + assert.Equal(t, r.Output.Results[2].Index, 2) + assert.Equal(t, r.Output.Results[0].RelevanceScore, float32(0.0)) + assert.Equal(t, r.Output.Results[1].RelevanceScore, float32(0.1)) + assert.Equal(t, r.Output.Results[2].RelevanceScore, float32(0.2)) + } +} + +func TestRerankFailed(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + + defer ts.Close() + url := ts.URL + + { + c := NewAliDashScopeRerank("mock_key") + _, err := c.Rerank(url, "gte-rerank-v2", "query", []string{"t1", "t2", "t3"}, nil, 0) + assert.True(t, err != nil) + } +} diff --git a/internal/util/function/rerank/ali_rerank_provider.go b/internal/util/function/rerank/ali_rerank_provider.go new file mode 100644 index 0000000000..4861c7f366 --- /dev/null +++ b/internal/util/function/rerank/ali_rerank_provider.go @@ -0,0 +1,86 @@ +/* + * # Licensed to the LF AI & Data foundation under one + * # or more contributor license agreements. See the NOTICE file + * # distributed with this work for additional information + * # regarding copyright ownership. The ASF licenses this file + * # to you under the Apache License, Version 2.0 (the + * # "License"); you may not use this file except in compliance + * # with the License. You may obtain a copy of the License at + * # + * # http://www.apache.org/licenses/LICENSE-2.0 + * # + * # Unless required by applicable law or agreed to in writing, software + * # distributed under the License is distributed on an "AS IS" BASIS, + * # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * # See the License for the specific language governing permissions and + * # limitations under the License. + */ + +package rerank + +import ( + "context" + "fmt" + "strings" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus/internal/util/credentials" + "github.com/milvus-io/milvus/internal/util/function/models" + "github.com/milvus-io/milvus/internal/util/function/models/ali" +) + +type aliProvider struct { + baseProvider + client *ali.AliDashScopeRerank + url string + modelName string + params map[string]any +} + +func newAliProvider(params []*commonpb.KeyValuePair, conf map[string]string, credentials *credentials.Credentials) (modelProvider, error) { + apiKey, url, err := models.ParseAKAndURL(credentials, params, conf, models.DashscopeAKEnvStr) + if err != nil { + return nil, err + } + if url == "" { + url = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank" + } + client := ali.NewAliDashScopeRerank(apiKey) + var modelName string + truncateParams := map[string]any{} + maxBatch := 128 + for _, param := range params { + switch strings.ToLower(param.Key) { + case models.ModelNameParamKey: + modelName = param.Value + case models.MaxClientBatchSizeParamKey: + if maxBatch, err = parseMaxBatch(param.Value); err != nil { + return nil, err + } + default: + } + } + if modelName == "" { + return nil, fmt.Errorf("ali rerank model name is required") + } + provider := aliProvider{ + baseProvider: baseProvider{batchSize: maxBatch}, + client: client, + url: url, + modelName: modelName, + params: truncateParams, + } + return &provider, nil +} + +func (provider *aliProvider) rerank(ctx context.Context, query string, docs []string) ([]float32, error) { + rerankResp, err := provider.client.Rerank(provider.url, provider.modelName, query, docs, provider.params, 30) + if err != nil { + return nil, err + } + scores := make([]float32, len(docs)) + for i, rerankResult := range rerankResp.Output.Results { + scores[i] = rerankResult.RelevanceScore + } + return scores, nil +} diff --git a/internal/util/function/rerank/model_function.go b/internal/util/function/rerank/model_function.go index 6737f11d5c..a543c22626 100644 --- a/internal/util/function/rerank/model_function.go +++ b/internal/util/function/rerank/model_function.go @@ -39,6 +39,7 @@ const ( siliconflowProviderName string = "siliconflow" cohereProviderName string = "cohere" voyageaiProviderName string = "voyageai" + aliProviderName string = "ali" queryKeyName string = "queries" ) @@ -84,6 +85,8 @@ func newProvider(params []*commonpb.KeyValuePair) (modelProvider, error) { return newCohereProvider(params, conf, credentials) case voyageaiProviderName: return newVoyageaiProvider(params, conf, credentials) + case aliProviderName: + return newAliProvider(params, conf, credentials) default: return nil, fmt.Errorf("Unknow rerank model provider:%s", param.Value) } diff --git a/internal/util/function/rerank/model_function_test.go b/internal/util/function/rerank/model_function_test.go index eb22508416..69c856a470 100644 --- a/internal/util/function/rerank/model_function_test.go +++ b/internal/util/function/rerank/model_function_test.go @@ -250,6 +250,24 @@ func (s *RerankModelSuite) TestNewProvider() { _, err := newProvider(params) s.NoError(err) } + { + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "ali"}, + {Key: models.CredentialParamKey, Value: "mock"}, + } + _, err := newProvider(params) + s.ErrorContains(err, "ali rerank model name is required") + } + { + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "ali"}, + {Key: models.ModelNameParamKey, Value: "ali-test"}, + {Key: models.CredentialParamKey, Value: "mock"}, + {Key: models.MaxClientBatchSizeParamKey, Value: "10"}, + } + _, err := newProvider(params) + s.NoError(err) + } } func (s *RerankModelSuite) TestCallVllm() { @@ -404,6 +422,28 @@ func (s *RerankModelSuite) TestCallVoyageAI() { } } +func (s *RerankModelSuite) TestCallAli() { + { + repStr := `{"output":{"results":[{"index":0,"relevance_score":0},{"index":1,"relevance_score":0.1},{"index":2,"relevance_score":0.2}]},"usage":{"total_tokens":1},"request_id":"x"}` + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(repStr)) + })) + defer ts.Close() + params := []*commonpb.KeyValuePair{ + {Key: providerParamName, Value: "ali"}, + {Key: models.ModelNameParamKey, Value: "ali-test"}, + {Key: models.CredentialParamKey, Value: "mock"}, + } + provder, err := newAliProvider(params, map[string]string{models.URLParamKey: ts.URL}, credentials.NewCredentials(map[string]string{"mock.apikey": "mock"})) + s.NoError(err) + scores, err := provder.rerank(context.Background(), "mytest", []string{"t1", "t2", "t3"}) + s.NoError(err) + s.Equal([]float32{0.0, 0.1, 0.2}, scores) + } +} + func (s *RerankModelSuite) TestNewModelFunction() { schema := &schemapb.CollectionSchema{ Name: "test",