mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
issue: #44363 --------- Signed-off-by: PjJinchen <6268414+pj1987111@users.noreply.github.com>
181 lines
4.6 KiB
Go
181 lines
4.6 KiB
Go
// Licensed to the LF AI & Data foundation under one
|
|
// or more contributor license agreements. See the NOTICE file
|
|
// distributed with this work for additional information
|
|
// regarding copyright ownership. The ASF licenses this file
|
|
// to you under the Apache License, Version 2.0 (the
|
|
// "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package ali
|
|
|
|
import (
|
|
"fmt"
|
|
"sort"
|
|
|
|
"github.com/cockroachdb/errors"
|
|
|
|
"github.com/milvus-io/milvus/internal/util/function/models"
|
|
)
|
|
|
|
// Input is the "input" object of an embedding request. Texts holds the
// strings to embed and is serialized under the JSON key "texts".
type Input struct {
	Texts []string `json:"texts"`
}
|
|
|
|
// Parameters is the optional "parameters" object of an embedding request.
// Every field is tagged omitempty, so a zero value ("" or 0) is left out of
// the encoded JSON entirely.
//
// NOTE(review): TextType/OutputType accepted values are defined by the
// DashScope API (presumably query/document and dense/sparse selectors) —
// confirm against the service documentation.
type Parameters struct {
	TextType   string `json:"text_type,omitempty"`
	Dimension  int    `json:"dimension,omitempty"`
	OutputType string `json:"output_type,omitempty"`
}
|
|
|
|
type EmbeddingRequest struct {
|
|
// ID of the model to use.
|
|
Model string `json:"model"`
|
|
|
|
// Input text to embed, encoded as a string.
|
|
Input Input `json:"input"`
|
|
|
|
Parameters Parameters `json:"parameters,omitempty"`
|
|
}
|
|
|
|
// Usage is the token accounting reported alongside a response.
type Usage struct {
	// TotalTokens is the total number of tokens consumed by the request.
	TotalTokens int `json:"total_tokens"`
}
|
|
|
|
// SparseEmbedding is a single entry of a sparse embedding, decoded from the
// keys "index", "value" and "token".
//
// NOTE(review): presumably Index is the vocabulary/dimension id carrying
// Value, and Token is the text token it corresponds to — confirm against
// the DashScope documentation.
type SparseEmbedding struct {
	Index int     `json:"index"`
	Value float32 `json:"value"`
	Token string  `json:"token"`
}
|
|
|
|
type Embeddings struct {
|
|
TextIndex int `json:"text_index"`
|
|
Embedding []float32 `json:"embedding,omitempty"`
|
|
SparseEmbedding []SparseEmbedding `json:"sparse_embedding,omitempty"`
|
|
}
|
|
|
|
type Output struct {
|
|
Embeddings []Embeddings `json:"embeddings"`
|
|
}
|
|
|
|
type EmbeddingResponse struct {
|
|
Output Output `json:"output"`
|
|
Usage Usage `json:"usage"`
|
|
RequestID string `json:"request_id"`
|
|
}
|
|
|
|
// ErrorInfo decodes an error payload with the keys "code", "message" and
// "request_id".
//
// NOTE(review): not referenced in this part of the file; presumably used to
// decode DashScope error responses elsewhere — confirm.
type ErrorInfo struct {
	Code      string `json:"code"`
	Message   string `json:"message"`
	RequestID string `json:"request_id"`
}
|
|
|
|
// AliDashScopeEmbedding is a client for a DashScope text-embedding endpoint.
// apiKey is sent as a Bearer token, and url is the endpoint every Embedding
// call POSTs to. Construct it with NewAliDashScopeEmbeddingClient.
type AliDashScopeEmbedding struct {
	apiKey, url string
}
|
|
|
|
func NewAliDashScopeEmbeddingClient(apiKey string, url string) *AliDashScopeEmbedding {
|
|
return &AliDashScopeEmbedding{
|
|
apiKey: apiKey,
|
|
url: url,
|
|
}
|
|
}
|
|
|
|
func (c *AliDashScopeEmbedding) Check() error {
|
|
if c.apiKey == "" {
|
|
return errors.New("api key is empty")
|
|
}
|
|
|
|
if c.url == "" {
|
|
return errors.New("url is empty")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *AliDashScopeEmbedding) Embedding(modelName string, texts []string, dim int, textType string, outputType string, timeoutSec int64) (*EmbeddingResponse, error) {
|
|
var r EmbeddingRequest
|
|
r.Model = modelName
|
|
r.Input = Input{texts}
|
|
r.Parameters.Dimension = dim
|
|
r.Parameters.TextType = textType
|
|
r.Parameters.OutputType = outputType
|
|
|
|
headers := map[string]string{
|
|
"Content-Type": "application/json",
|
|
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
|
|
}
|
|
res, err := models.PostRequest[EmbeddingResponse](r, c.url, headers, timeoutSec)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sort.Slice(res.Output.Embeddings, func(i, j int) bool {
|
|
return res.Output.Embeddings[i].TextIndex < res.Output.Embeddings[j].TextIndex
|
|
})
|
|
return res, nil
|
|
}
|
|
|
|
// Inputs is the "input" object of a rerank request: the query and the
// candidate documents to score against it. Keep Query before Documents,
// because unkeyed literals such as Inputs{query, texts} rely on this order.
type Inputs struct {
	Query     string   `json:"query"`
	Documents []string `json:"documents"`
}
|
|
type RerankRequest struct {
|
|
// ID of the model to use.
|
|
Model string `json:"model"`
|
|
|
|
// Input text to embed, encoded as a string.
|
|
Inputs Inputs `json:"input"`
|
|
}
|
|
|
|
// RerankItem is one scored document in a rerank response.
//
// NOTE(review): Index is presumably the document's position in the
// request's Documents list — confirm against the DashScope documentation.
type RerankItem struct {
	Index          int     `json:"index"`
	RelevanceScore float32 `json:"relevance_score"`
}
|
|
|
|
type RerankOutput struct {
|
|
Results []RerankItem `json:"results"`
|
|
}
|
|
|
|
type RerankResponse struct {
|
|
Output RerankOutput `json:"output"`
|
|
Usage Usage `json:"usage"`
|
|
RequestID string `json:"request_id"`
|
|
}
|
|
|
|
// AliDashScopeRerank is a client for DashScope rerank endpoints. It holds
// only the API key, which is sent as a Bearer token. The endpoint URL is
// supplied to each Rerank call. Construct it with NewAliDashScopeRerank.
type AliDashScopeRerank struct {
	apiKey string
}
|
|
|
|
func NewAliDashScopeRerank(apiKey string) *AliDashScopeRerank {
|
|
return &AliDashScopeRerank{
|
|
apiKey: apiKey,
|
|
}
|
|
}
|
|
|
|
func (c *AliDashScopeRerank) Rerank(url string, modelName string, query string, texts []string, params map[string]any, timeoutSec int64) (*RerankResponse, error) {
|
|
var r RerankRequest
|
|
r.Model = modelName
|
|
r.Inputs = Inputs{query, texts}
|
|
|
|
headers := map[string]string{
|
|
"Content-Type": "application/json",
|
|
"Authorization": fmt.Sprintf("Bearer %s", c.apiKey),
|
|
}
|
|
res, err := models.PostRequest[RerankResponse](r, url, headers, timeoutSec)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sort.Slice(res.Output.Results, func(i, j int) bool {
|
|
return res.Output.Results[i].Index < res.Output.Results[j].Index
|
|
})
|
|
return res, nil
|
|
}
|