milvus/internal/util/function/models/ali/ali_dashscope_client.go
PjJinchen d8efe8a6fb
feat: support ali qwen rerank (#44363) (#44364)
issue: #44363

---------

Signed-off-by: PjJinchen <6268414+pj1987111@users.noreply.github.com>
2025-09-30 23:19:52 +08:00

181 lines
4.6 KiB
Go

// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ali
import (
"fmt"
"sort"
"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus/internal/util/function/models"
)
// Input is the "input" object of an EmbeddingRequest: the batch of texts to
// embed.
type Input struct {
	Texts []string `json:"texts"`
}
// Parameters is the "parameters" object of an EmbeddingRequest.
// AliDashScopeEmbedding.Embedding fills it from its textType, dim and
// outputType arguments. Every field is tagged omitempty, so an empty string
// or 0 is left out of the JSON.
// NOTE(review): the accepted values for text_type / output_type are not
// visible here; confirm against the DashScope embedding API docs.
type Parameters struct {
	TextType   string `json:"text_type,omitempty"`
	Dimension  int    `json:"dimension,omitempty"`
	OutputType string `json:"output_type,omitempty"`
}
// EmbeddingRequest is the JSON body sent by AliDashScopeEmbedding.Embedding.
type EmbeddingRequest struct {
	// ID of the model to use.
	Model string `json:"model"`
	// Texts to embed, sent as the "input" object ({"texts": [...]}).
	Input Input `json:"input"`
	// Optional embedding parameters. encoding/json does not treat a non-pointer
	// struct as "empty", so omitempty has no effect here: "parameters" is always
	// present in the body (as {} when every field of Parameters is zero).
	Parameters Parameters `json:"parameters,omitempty"`
}
// Usage is the "usage" object decoded from both EmbeddingResponse and
// RerankResponse.
type Usage struct {
	// The total number of tokens used by the request.
	TotalTokens int `json:"total_tokens"`
}
// SparseEmbedding is one entry of a sparse embedding vector as returned in
// Embeddings.SparseEmbedding.
// NOTE(review): presumably Index is the sparse dimension id, Value its
// weight, and Token the text token that produced it. Confirm against the
// DashScope API docs.
type SparseEmbedding struct {
	Index int     `json:"index"`
	Value float32 `json:"value"`
	Token string  `json:"token"`
}
// Embeddings is the result for a single input text.
// AliDashScopeEmbedding.Embedding sorts the returned slice by TextIndex.
// NOTE(review): TextIndex is presumably the position of the text in
// Input.Texts. Which of Embedding (dense) / SparseEmbedding is populated
// presumably depends on the requested output_type. Confirm both.
type Embeddings struct {
	TextIndex       int               `json:"text_index"`
	Embedding       []float32         `json:"embedding,omitempty"`
	SparseEmbedding []SparseEmbedding `json:"sparse_embedding,omitempty"`
}
// Output is the "output" object of an EmbeddingResponse.
type Output struct {
	Embeddings []Embeddings `json:"embeddings"`
}
// EmbeddingResponse is the JSON body that models.PostRequest decodes for an
// embedding call made by AliDashScopeEmbedding.Embedding.
type EmbeddingResponse struct {
	Output    Output `json:"output"`
	Usage     Usage  `json:"usage"`
	RequestID string `json:"request_id"`
}
// ErrorInfo has the fields code, message and request_id.
// NOTE(review): not referenced by the code visible in this file. Presumably
// it mirrors the DashScope error response body and is used elsewhere to
// decode failures. Confirm before relying on it.
type ErrorInfo struct {
	Code      string `json:"code"`
	Message   string `json:"message"`
	RequestID string `json:"request_id"`
}
// AliDashScopeEmbedding is a client for a DashScope embedding endpoint.
// Create it with NewAliDashScopeEmbeddingClient. Check reports whether both
// fields are set.
type AliDashScopeEmbedding struct {
	apiKey string // sent as "Authorization: Bearer <apiKey>"
	url    string // endpoint passed to models.PostRequest
}
// NewAliDashScopeEmbeddingClient returns a new embedding client that
// authenticates with apiKey and sends requests to url. The arguments are
// stored as-is; use Check to verify that neither is empty.
func NewAliDashScopeEmbeddingClient(apiKey string, url string) *AliDashScopeEmbedding {
	client := new(AliDashScopeEmbedding)
	client.apiKey, client.url = apiKey, url
	return client
}
// Check returns an error if the client cannot issue requests. The API key is
// validated first, then the endpoint URL. It returns nil when both are
// non-empty.
func (c *AliDashScopeEmbedding) Check() error {
	switch {
	case c.apiKey == "":
		return errors.New("api key is empty")
	case c.url == "":
		return errors.New("url is empty")
	default:
		return nil
	}
}
func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, timeoutSec int64) (*EmbeddingResponse, error) {
var r EmbeddingRequest
r.Model = modelName
r.Input = Input{texts}
r.Parameters.Dimension = dim
r.Parameters.TextType = textType
r.Parameters.OutputType = outputType
headers := map[string]string{
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
}
res, err := models.PostRequest[EmbeddingResponse](r, c.url, headers, timeoutSec)
if err != nil {
return nil, err
}
sort.Slice(res.Output.Embeddings, func(i, j int) bool {
return res.Output.Embeddings[i].TextIndex < res.Output.Embeddings[j].TextIndex
})
return res, nil
}
// Inputs is the "input" object of a RerankRequest: the query and the
// candidate documents to be scored against it.
type Inputs struct {
	Query     string   `json:"query"`
	Documents []string `json:"documents"`
}
// RerankRequest carries the model and input fields of the JSON body sent by
// AliDashScopeRerank.Rerank.
type RerankRequest struct {
	// ID of the model to use.
	Model string `json:"model"`
	// Query and candidate documents to rerank, sent as the "input" object.
	Inputs Inputs `json:"input"`
}
// RerankItem is one scored document in a rerank response.
// AliDashScopeRerank.Rerank sorts results by Index.
// NOTE(review): Index is presumably the document's position in
// Inputs.Documents. Confirm against the DashScope rerank API docs.
type RerankItem struct {
	Index          int     `json:"index"`
	RelevanceScore float32 `json:"relevance_score"`
}
// RerankOutput is the "output" object of a RerankResponse.
type RerankOutput struct {
	Results []RerankItem `json:"results"`
}
// RerankResponse is the JSON body that models.PostRequest decodes for a
// rerank call made by AliDashScopeRerank.Rerank.
type RerankResponse struct {
	Output    RerankOutput `json:"output"`
	Usage     Usage        `json:"usage"`
	RequestID string       `json:"request_id"`
}
// AliDashScopeRerank is a client for DashScope rerank. Unlike
// AliDashScopeEmbedding, it stores only the API key; the endpoint URL is
// passed to each Rerank call.
type AliDashScopeRerank struct {
	apiKey string // sent as "Authorization: Bearer <apiKey>"
}
// NewAliDashScopeRerank returns a new rerank client that authenticates with
// apiKey. The key is stored as-is, without validation.
func NewAliDashScopeRerank(apiKey string) *AliDashScopeRerank {
	client := new(AliDashScopeRerank)
	client.apiKey = apiKey
	return client
}
func (c *AliDashScopeRerank) Rerank(url string, modelName string, query string, texts []string, params map[string]any, timeoutSec int64) (*RerankResponse, error) {
var r RerankRequest
r.Model = modelName
r.Inputs = Inputs{query, texts}
headers := map[string]string{
"Content-Type": "application/json",
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
}
res, err := models.PostRequest[RerankResponse](r, url, headers, timeoutSec)
if err != nil {
return nil, err
}
sort.Slice(res.Output.Results, func(i, j int) bool {
return res.Output.Results[i].Index < res.Output.Results[j].Index
})
return res, nil
}