mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-08 18:18:30 +08:00
issue: https://github.com/milvus-io/milvus/issues/39818 This PR makes the Text data type mimic the VarChar data type: it allows insert, search, query, delete, full-text search, and other operations. Functionality related to filter expressions is temporarily disabled. Storage changes for the Text data type will follow in later PRs. Signed-off-by: Patrick Weizhi Xu <weizhi.xu@zilliz.com>
215 lines
7.4 KiB
Go
215 lines
7.4 KiB
Go
/*
|
|
* # Licensed to the LF AI & Data foundation under one
|
|
* # or more contributor license agreements. See the NOTICE file
|
|
* # distributed with this work for additional information
|
|
* # regarding copyright ownership. The ASF licenses this file
|
|
* # to you under the Apache License, Version 2.0 (the
|
|
* # "License"); you may not use this file except in compliance
|
|
* # with the License. You may obtain a copy of the License at
|
|
* #
|
|
* # http://www.apache.org/licenses/LICENSE-2.0
|
|
* #
|
|
* # Unless required by applicable law or agreed to in writing, software
|
|
* # distributed under the License is distributed on an "AS IS" BASIS,
|
|
* # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* # See the License for the specific language governing permissions and
|
|
* # limitations under the License.
|
|
*/
|
|
|
|
package function
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/internal/storage"
|
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
|
)
|
|
|
|
// Provider is the function-parameter key whose value names the text
// embedding service provider.
const Provider string = "provider"
|
|
|
|
// Names of the supported text embedding service providers, listed
// alphabetically by identifier.
const (
	aliDashScopeProvider string = "dashscope"    // Alibaba DashScope
	azureOpenAIProvider  string = "azure_openai" // Azure OpenAI
	bedrockProvider      string = "bedrock"      // Bedrock
	cohereProvider       string = "cohere"       // Cohere
	openAIProvider       string = "openai"       // OpenAI
	vertexAIProvider     string = "vertexai"     // Vertex AI
	voyageAIProvider     string = "voyageai"     // Voyage AI
)
|
|
|
|
// Text embedding for retrieval task
|
|
type textEmbeddingProvider interface {
|
|
MaxBatch() int
|
|
CallEmbedding(texts []string, mode TextEmbeddingMode) ([][]float32, error)
|
|
FieldDim() int64
|
|
}
|
|
|
|
func getProvider(functionSchema *schemapb.FunctionSchema) (string, error) {
|
|
for _, param := range functionSchema.Params {
|
|
switch strings.ToLower(param.Key) {
|
|
case Provider:
|
|
return strings.ToLower(param.Value), nil
|
|
default:
|
|
}
|
|
}
|
|
return "", fmt.Errorf("The text embedding service provider parameter:[%s] was not found", Provider)
|
|
}
|
|
|
|
type TextEmbeddingFunction struct {
|
|
FunctionBase
|
|
|
|
embProvider textEmbeddingProvider
|
|
}
|
|
|
|
func isValidInputDataType(dataType schemapb.DataType) bool {
|
|
return dataType == schemapb.DataType_VarChar || dataType == schemapb.DataType_Text
|
|
}
|
|
|
|
func NewTextEmbeddingFunction(coll *schemapb.CollectionSchema, functionSchema *schemapb.FunctionSchema) (*TextEmbeddingFunction, error) {
|
|
if len(functionSchema.GetOutputFieldNames()) != 1 {
|
|
return nil, fmt.Errorf("Text function should only have one output field, but now is %d", len(functionSchema.GetOutputFieldNames()))
|
|
}
|
|
|
|
base, err := NewFunctionBase(coll, functionSchema)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if base.outputFields[0].DataType != schemapb.DataType_FloatVector {
|
|
return nil, fmt.Errorf("Text embedding function's output field not match, needs [%s], got [%s]",
|
|
schemapb.DataType_name[int32(schemapb.DataType_FloatVector)],
|
|
schemapb.DataType_name[int32(base.outputFields[0].DataType)])
|
|
}
|
|
|
|
provider, err := getProvider(functionSchema)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var embP textEmbeddingProvider
|
|
var newProviderErr error
|
|
switch provider {
|
|
case openAIProvider:
|
|
embP, newProviderErr = NewOpenAIEmbeddingProvider(base.outputFields[0], functionSchema)
|
|
case azureOpenAIProvider:
|
|
embP, newProviderErr = NewAzureOpenAIEmbeddingProvider(base.outputFields[0], functionSchema)
|
|
case bedrockProvider:
|
|
embP, newProviderErr = NewBedrockEmbeddingProvider(base.outputFields[0], functionSchema, nil)
|
|
case aliDashScopeProvider:
|
|
embP, newProviderErr = NewAliDashScopeEmbeddingProvider(base.outputFields[0], functionSchema)
|
|
case vertexAIProvider:
|
|
embP, newProviderErr = NewVertexAIEmbeddingProvider(base.outputFields[0], functionSchema, nil)
|
|
case voyageAIProvider:
|
|
embP, newProviderErr = NewVoyageAIEmbeddingProvider(base.outputFields[0], functionSchema)
|
|
case cohereProvider:
|
|
embP, newProviderErr = NewCohereEmbeddingProvider(base.outputFields[0], functionSchema)
|
|
default:
|
|
return nil, fmt.Errorf("Unsupported text embedding service provider: [%s] , list of supported [%s, %s, %s, %s, %s, %s, %s]", provider, openAIProvider, azureOpenAIProvider, aliDashScopeProvider, bedrockProvider, vertexAIProvider, voyageAIProvider, cohereProvider)
|
|
}
|
|
|
|
if newProviderErr != nil {
|
|
return nil, newProviderErr
|
|
}
|
|
return &TextEmbeddingFunction{
|
|
FunctionBase: *base,
|
|
embProvider: embP,
|
|
}, nil
|
|
}
|
|
|
|
func (runner *TextEmbeddingFunction) MaxBatch() int {
|
|
return runner.embProvider.MaxBatch()
|
|
}
|
|
|
|
func (runner *TextEmbeddingFunction) ProcessInsert(inputs []*schemapb.FieldData) ([]*schemapb.FieldData, error) {
|
|
if len(inputs) != 1 {
|
|
return nil, fmt.Errorf("Text embedding function only receives one input field, but got [%d]", len(inputs))
|
|
}
|
|
|
|
if !isValidInputDataType(inputs[0].Type) {
|
|
return nil, fmt.Errorf("Text embedding only supports varchar or text field as input field, but got %s", schemapb.DataType_name[int32(inputs[0].Type)])
|
|
}
|
|
|
|
texts := inputs[0].GetScalars().GetStringData().GetData()
|
|
if texts == nil {
|
|
return nil, fmt.Errorf("Input texts is empty")
|
|
}
|
|
numRows := len(texts)
|
|
if numRows > runner.MaxBatch() {
|
|
return nil, fmt.Errorf("Embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows)
|
|
}
|
|
embds, err := runner.embProvider.CallEmbedding(texts, InsertMode)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
data := make([]float32, 0, len(texts)*int(runner.embProvider.FieldDim()))
|
|
for _, emb := range embds {
|
|
data = append(data, emb...)
|
|
}
|
|
|
|
var outputField schemapb.FieldData
|
|
outputField.FieldId = runner.outputFields[0].FieldID
|
|
outputField.FieldName = runner.outputFields[0].Name
|
|
outputField.Type = runner.outputFields[0].DataType
|
|
outputField.IsDynamic = runner.outputFields[0].IsDynamic
|
|
outputField.Field = &schemapb.FieldData_Vectors{
|
|
Vectors: &schemapb.VectorField{
|
|
Data: &schemapb.VectorField_FloatVector{
|
|
FloatVector: &schemapb.FloatArray{
|
|
Data: data,
|
|
},
|
|
},
|
|
Dim: runner.embProvider.FieldDim(),
|
|
},
|
|
}
|
|
return []*schemapb.FieldData{&outputField}, nil
|
|
}
|
|
|
|
func (runner *TextEmbeddingFunction) ProcessSearch(placeholderGroup *commonpb.PlaceholderGroup) (*commonpb.PlaceholderGroup, error) {
|
|
texts := funcutil.GetVarCharFromPlaceholder(placeholderGroup.Placeholders[0]) // Already checked externally
|
|
numRows := len(texts)
|
|
if numRows > runner.MaxBatch() {
|
|
return nil, fmt.Errorf("Embedding supports up to [%d] pieces of data at a time, got [%d]", runner.MaxBatch(), numRows)
|
|
}
|
|
embds, err := runner.embProvider.CallEmbedding(texts, SearchMode)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return funcutil.Float32VectorsToPlaceholderGroup(embds), nil
|
|
}
|
|
|
|
func (runner *TextEmbeddingFunction) ProcessBulkInsert(inputs []storage.FieldData) (map[storage.FieldID]storage.FieldData, error) {
|
|
if len(inputs) != 1 {
|
|
return nil, fmt.Errorf("TextEmbedding function only receives one input, bug got [%d]", len(inputs))
|
|
}
|
|
|
|
if !isValidInputDataType(inputs[0].GetDataType()) {
|
|
return nil, fmt.Errorf("TextEmbedding function only supports varchar or text field as input field, but got %s", schemapb.DataType_name[int32(inputs[0].GetDataType())])
|
|
}
|
|
|
|
texts, ok := inputs[0].GetDataRows().([]string)
|
|
if !ok {
|
|
return nil, fmt.Errorf("Input texts is empty")
|
|
}
|
|
|
|
embds, err := runner.embProvider.CallEmbedding(texts, InsertMode)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
data := make([]float32, 0, len(texts)*int(runner.embProvider.FieldDim()))
|
|
for _, emb := range embds {
|
|
data = append(data, emb...)
|
|
}
|
|
|
|
field := &storage.FloatVectorFieldData{
|
|
Data: data,
|
|
Dim: int(runner.embProvider.FieldDim()),
|
|
}
|
|
return map[storage.FieldID]storage.FieldData{
|
|
runner.outputFields[0].FieldID: field,
|
|
}, nil
|
|
}
|