junjiejiangjjj f1ce84996d
enhance: refactor model service configuration and environment variables (#44036)
- Add enable configuration for all model service providers
- Migrate environment variables from MILVUSAI_* to MILVUS_* prefix with
backward compatibility
- Unify model service enable/disable logic using configuration
- Add tests for environment variable parsing with fallback support

#35856

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
2025-08-26 10:49:52 +08:00

322 lines
8.0 KiB
Go

/*
* # Licensed to the LF AI & Data foundation under one
* # or more contributor license agreements. See the NOTICE file
* # distributed with this work for additional information
* # regarding copyright ownership. The ASF licenses this file
* # to you under the Apache License, Version 2.0 (the
* # "License"); you may not use this file except in compliance
* # with the License. You may obtain a copy of the License at
* #
* # http://www.apache.org/licenses/LICENSE-2.0
* #
* # Unless required by applicable law or agreed to in writing, software
* # distributed under the License is distributed on an "AS IS" BASIS,
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* # See the License for the specific language governing permissions and
* # limitations under the License.
*/
package models
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/util/credentials"
)
// TextEmbeddingMode is an int-backed enumeration with the two values declared below.
// NOTE(review): the names suggest it distinguishes embedding at insert time from
// embedding at search time; no consumer is visible in this part of the file - confirm.
type TextEmbeddingMode int
const (
// InsertMode is the zero value of TextEmbeddingMode.
InsertMode TextEmbeddingMode = iota
// SearchMode has the value 1.
SearchMode
)
// EmbeddingType identifies the element type of an embedding vector.
// GetEmbdType maps a schema data type onto one of these values, and EmbdResult
// stores one to record which kind of vectors it was created for.
type EmbeddingType int
const (
// UnsupportEmbd is the zero value; GetEmbdType returns it for every data type
// other than FloatVector and Int8Vector.
UnsupportEmbd EmbeddingType = iota
// Float32Embd marks float32 embeddings (schemapb.DataType_FloatVector).
Float32Embd
// Int8Embd marks int8 embeddings (schemapb.DataType_Int8Vector).
Int8Embd
)
// EnableConf is the configuration key that IsEnable looks up to decide whether
// a model service is enabled.
const EnableConf string = "enable"
// Common parameter keys.
// ParseAKAndURL reads CredentialParamKey from both the function params and the
// milvus.yaml settings, and URLParamKey from the milvus.yaml settings only. The
// remaining keys are not referenced in this part of the file.
const (
ModelNameParamKey string = "model_name"
DimParamKey string = "dim"
URLParamKey string = "url"
CredentialParamKey string = "credential"
TruncateParamKey string = "truncate"
MaxClientBatchSizeParamKey string = "max_client_batch_size"
)
// Ali (Dashscope) text embedding.
const (
// DashscopeAKEnvStr names the environment variable for the Dashscope API key.
// NOTE(review): presumably passed to ParseAKAndURL as apiKeyEnv, which also
// falls back to the deprecated MILVUSAI_ prefix - confirm at the call site.
DashscopeAKEnvStr string = "MILVUS_DASHSCOPE_API_KEY"
)
// OpenAI / Azure OpenAI text embedding.
const (
// OpenaiAKEnvStr names the environment variable for the OpenAI API key.
OpenaiAKEnvStr string = "MILVUS_OPENAI_API_KEY"
// AzureOpenaiAKEnvStr names the environment variable for the Azure OpenAI API key.
AzureOpenaiAKEnvStr string = "MILVUS_AZURE_OPENAI_API_KEY"
// AzureOpenaiResourceName names the environment variable for the Azure OpenAI
// resource name. NOTE(review): where it is read is not visible here - confirm.
AzureOpenaiResourceName string = "MILVUS_AZURE_OPENAI_RESOURCE_NAME"
// UserParamKey is the "user" parameter key.
UserParamKey string = "user"
)
// Bedrock embedding.
const (
// The two parameter keys below are commented out and therefore not declared.
// awsAKIdParamKey string = "aws_access_key_id"
// awsSAKParamKey string = "aws_secret_access_key"
RegionParamKey string = "region"
NormalizeParamKey string = "normalize"
// BedrockAccessKeyId and BedrockSAKEnvStr name the environment variables for
// the Bedrock access key ID and secret access key respectively.
BedrockAccessKeyId string = "MILVUS_BEDROCK_ACCESS_KEY_ID"
BedrockSAKEnvStr string = "MILVUS_BEDROCK_SECRET_ACCESS_KEY"
)
// VertexAI.
const (
LocationParamKey string = "location"
ProjectIDParamKey string = "projectid"
TaskTypeParamKey string = "task"
// VertexServiceAccountJSONEnv names the environment variable for the Google
// application credentials. NOTE(review): whether the variable holds a file path
// or the JSON content itself is not visible here - confirm.
VertexServiceAccountJSONEnv string = "MILVUS_GOOGLE_APPLICATION_CREDENTIALS"
)
// VoyageAI.
const (
// TruncationParamKey is distinct from the common TruncateParamKey ("truncate").
TruncationParamKey string = "truncation"
// VoyageAIAKEnvStr names the environment variable for the VoyageAI API key.
VoyageAIAKEnvStr string = "MILVUS_VOYAGEAI_API_KEY"
)
// Cohere.
const (
MaxTKsPerDocParamKey string = "max_tokens_per_doc"
// CohereAIAKEnvStr names the environment variable for the Cohere API key.
CohereAIAKEnvStr string = "MILVUS_COHERE_API_KEY"
)
// Siliconflow.
const (
MaxChunksPerDocParamKey string = "max_chunks_per_doc"
OverlapTokensParamKey string = "overlap_tokens"
// SiliconflowAKEnvStr names the environment variable for the Siliconflow API key.
SiliconflowAKEnvStr string = "MILVUS_SILICONFLOW_API_KEY"
)
// TEI and vLLM.
const (
IngestionPromptParamKey string = "ingestion_prompt"
SearchPromptParamKey string = "search_prompt"
TruncationDirectionParamKey string = "truncation_direction"
EndpointParamKey string = "endpoint"
// The two names below differ per backend.
// NOTE(review): presumably the field names used in the request sent to each
// service - confirm against the TEI / vLLM clients.
VllmTruncateParamName string = "truncate_prompt_tokens"
TeiTruncateParamName string = "truncate"
)
// ParseAKAndURL resolves the API key and the service URL for a model provider.
//
// The API key is taken from the first source that yields a non-empty value:
//  1. the credential named by a "credential" entry in params (the key is
//     matched case-insensitively; if several entries are present the last wins),
//  2. the credential named by confParams["credential"] (milvus.yaml),
//  3. the environment variable named by apiKeyEnv,
//  4. the deprecated MILVUSAI_-prefixed spelling of apiKeyEnv.
//
// The URL always comes from confParams["url"]; it cannot be supplied through
// function params or the environment.
//
// A non-nil error is returned only when a named credential cannot be resolved;
// the error from GetAPIKeyCredential is passed through unchanged. An empty API
// key is not an error.
func ParseAKAndURL(credentials *credentials.Credentials, params []*commonpb.KeyValuePair, confParams map[string]string, apiKeyEnv string) (string, string, error) {
	var apiKey string

	// 1. Function params.
	for _, kv := range params {
		if strings.ToLower(kv.Key) != CredentialParamKey {
			continue
		}
		key, err := credentials.GetAPIKeyCredential(kv.Value)
		if err != nil {
			return "", "", err
		}
		apiKey = key
	}

	// 2. milvus.yaml.
	if apiKey == "" {
		if name := confParams[CredentialParamKey]; name != "" {
			key, err := credentials.GetAPIKeyCredential(name)
			if err != nil {
				return "", "", err
			}
			apiKey = key
		}
	}

	// 3. Environment variable under its current (MILVUS_*) name.
	if apiKey == "" {
		apiKey = os.Getenv(apiKeyEnv)
	}

	// 4. DEPRECATED: MILVUSAI_* env variables will be removed in Milvus 3.0.
	// Use the MILVUS_* name instead.
	if apiKey == "" {
		legacyEnv := "MILVUSAI_" + strings.TrimPrefix(apiKeyEnv, "MILVUS_")
		apiKey = os.Getenv(legacyEnv)
	}

	return apiKey, confParams[URLParamKey], nil
}
// ParseAndCheckFieldDim parses dimStr as a base-10 int64 and verifies it
// against the dimension of the function's output field.
//
// A parsed value of 0 is accepted without comparison; any other value must
// equal fieldDim. On success the parsed value is returned (so it may be 0).
// An error is returned when dimStr is not a valid integer or when the non-zero
// value differs from fieldDim; fieldName is only used in the error message.
func ParseAndCheckFieldDim(dimStr string, fieldDim int64, fieldName string) (int64, error) {
	parsed, parseErr := strconv.ParseInt(dimStr, 10, 64)
	switch {
	case parseErr != nil:
		return 0, fmt.Errorf("dimension [%s] provided in Function params is not a valid int", dimStr)
	case parsed == 0, parsed == fieldDim:
		return parsed, nil
	default:
		return 0, fmt.Errorf("Function output field:[%s]'s dimension [%d] does not match the dimension [%d] provided in Function params.", fieldName, fieldDim, parsed)
	}
}
// GetEmbdType maps a schema data type to the embedding element type:
// FloatVector yields Float32Embd, Int8Vector yields Int8Embd, and every other
// data type yields UnsupportEmbd.
func GetEmbdType(dtype schemapb.DataType) EmbeddingType {
	if dtype == schemapb.DataType_FloatVector {
		return Float32Embd
	}
	if dtype == schemapb.DataType_Int8Vector {
		return Int8Embd
	}
	return UnsupportEmbd
}
// EmbdResult accumulates embedding vectors.
// NewEmbdResult pre-allocates exactly one of the two slices depending on the
// requested embedding type, and Append routes vectors by their Go element type.
type EmbdResult struct {
// FloatEmbds holds float32 vectors; NewEmbdResult allocates it when the type is Float32Embd.
FloatEmbds [][]float32
// Int8Embds holds int8 vectors; NewEmbdResult allocates it for every other type.
Int8Embds [][]int8
// EmbdType records the embedding type passed to NewEmbdResult.
EmbdType EmbeddingType
}
// NewEmbdResult returns an empty EmbdResult tagged with eType whose backing
// slice has capacity size.
//
// FloatEmbds is allocated when eType is Float32Embd; for every other eType
// (including UnsupportEmbd) Int8Embds is allocated instead. The slice that is
// not selected stays nil.
func NewEmbdResult(size int, eType EmbeddingType) *EmbdResult {
	result := &EmbdResult{EmbdType: eType}
	switch eType {
	case Float32Embd:
		result.FloatEmbds = make([][]float32, 0, size)
	default:
		result.Int8Embds = make([][]int8, 0, size)
	}
	return result
}
// Append adds newEmbd to the result, choosing the destination by dynamic type:
//   - []float32 is appended to FloatEmbds as a single vector,
//   - [][]float32 has all of its vectors appended to FloatEmbds,
//   - []int8 is appended to Int8Embds as a single vector,
//   - [][]int8 has all of its vectors appended to Int8Embds.
//
// Values of any other type are ignored. EmbdType is not consulted.
func (embR *EmbdResult) Append(newEmbd any) {
	switch v := newEmbd.(type) {
	// float32 vectors: one at a time or as a batch
	case []float32:
		embR.FloatEmbds = append(embR.FloatEmbds, v)
	case [][]float32:
		embR.FloatEmbds = append(embR.FloatEmbds, v...)
	// int8 vectors: one at a time or as a batch
	case []int8:
		embR.Int8Embds = append(embR.Int8Embds, v)
	case [][]int8:
		embR.Int8Embds = append(embR.Int8Embds, v...)
	}
}
// NewBaseURL parses endpoint and checks that it is an absolute http or https
// URL with a non-empty host.
//
// A parse failure is returned as-is. A URL with any other scheme, or without a
// host, is rejected with an "is not a valid http/https link" error.
func NewBaseURL(endpoint string) (*url.URL, error) {
	parsed, err := url.Parse(endpoint)
	if err != nil {
		return nil, err
	}
	isHTTP := parsed.Scheme == "http" || parsed.Scheme == "https"
	if !isHTTP || parsed.Host == "" {
		return nil, fmt.Errorf("endpoint: [%s] is not a valid http/https link", endpoint)
	}
	return parsed, nil
}
// IsEnable reports whether the service described by conf (its milvus.yaml
// settings) is enabled.
//
// The service is disabled only when conf[EnableConf] equals "false",
// compared case-insensitively and without trimming whitespace. A missing key,
// an empty value, or any other value leaves it enabled.
func IsEnable(conf map[string]string) bool {
	disabled := strings.ToLower(conf[EnableConf]) == "false"
	return !disabled
}
// Response is the type constraint for the payload type that PostRequest
// decodes into. It is defined as any, so it places no restriction on T.
type Response any
// PostRequest JSON-encodes req, POSTs it to url with the given headers and
// decodes the JSON response body into a new T.
//
// timeoutSec bounds the whole call, including every retry and backoff wait;
// values <= 0 select the default of 30 seconds. The request is handed to
// retrySend with a budget of 3 attempts.
//
// Errors:
//   - a json.Marshal failure for req is returned unchanged,
//   - a transport or non-200 failure is returned unchanged from retrySend,
//   - a response that is not valid JSON for T is reported as
//     "Call service failed, unmarshal response failed, ..." and wraps the
//     decoder error, so errors.Is / errors.As can still inspect it.
func PostRequest[T Response](req any, url string, headers map[string]string, timeoutSec int64) (*T, error) {
	payload, err := json.Marshal(req)
	if err != nil {
		return nil, err
	}

	if timeoutSec <= 0 {
		timeoutSec = 30
	}
	ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
	defer cancel()

	body, err := retrySend(ctx, payload, http.MethodPost, url, headers, 3)
	if err != nil {
		return nil, err
	}

	var res T
	if err := json.Unmarshal(body, &res); err != nil {
		// %w (not %v) keeps the decoder error reachable through the chain.
		return nil, fmt.Errorf("Call service failed, unmarshal response failed, errs:[%w]", err)
	}
	return &res, nil
}
// send executes req with http.DefaultClient and returns the complete response
// body.
//
// Any status other than 200 OK is treated as a failure; the status line and
// the body are included in the error message. Transport and read errors are
// wrapped, so the underlying cause stays reachable via errors.Is / errors.As.
func send(req *http.Request) ([]byte, error) {
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("Call service failed, errs:[%w]", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("Call service failed, read response failed, errs:[%w]", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("Call service failed, errs:[%s, %s]", resp.Status, body)
	}
	return body, nil
}

// retrySend sends data to url with the given method and headers, retrying on
// any failure reported by send (transport errors as well as non-200 responses).
//
// At most maxRetries attempts are made; values below 1 are treated as 1 so the
// function never returns (nil, nil). Between attempts it waits 1s, 2s, 4s, ...
// (capped at 32s) plus up to 25% random jitter. The wait is abandoned as soon
// as ctx is done, and no wait happens after the final attempt.
//
// On success the response body is returned. Otherwise the error of the last
// attempt is returned; if the retry loop was cut short by ctx, the context
// error is appended to the message. A request that cannot be constructed
// fails immediately without retrying.
func retrySend(ctx context.Context, data []byte, httpMethod string, url string, headers map[string]string, maxRetries int) ([]byte, error) {
	// Caps the delay at 1<<5 = 32s and keeps the shift from overflowing, which
	// would hand rand.Int63n a non-positive bound and make it panic.
	const maxBackoffShift = 5

	if maxRetries < 1 {
		maxRetries = 1
	}

	var lastErr error
	for attempt := 0; attempt < maxRetries; attempt++ {
		// The request is rebuilt on every attempt because its body reader is
		// consumed by the previous send.
		req, err := http.NewRequestWithContext(ctx, httpMethod, url, bytes.NewReader(data))
		if err != nil {
			return nil, err
		}
		for k, v := range headers {
			req.Header.Set(k, v)
		}

		body, err := send(req)
		if err == nil {
			return body, nil
		}
		lastErr = err

		// Nothing follows the last attempt, so do not make the caller wait.
		if attempt == maxRetries-1 {
			break
		}

		shift := attempt
		if shift > maxBackoffShift {
			shift = maxBackoffShift
		}
		backoff := time.Duration(1<<uint(shift)) * time.Second
		jitter := time.Duration(rand.Int63n(int64(backoff / 4)))

		timer := time.NewTimer(backoff + jitter)
		select {
		case <-ctx.Done():
			timer.Stop()
			// Report the real failure rather than letting a follow-up attempt
			// replace it with a bare context error.
			return nil, fmt.Errorf("%w, retry aborted: %v", lastErr, ctx.Err())
		case <-timer.C:
		}
	}
	return nil, lastErr
}