congqixia f5f053f1d2
enhance: Refactor privilege management by extracting privilege cache into separate package (#44762)
Related to #44761

This commit refactors the privilege management system in the proxy
component by:

1. **Separation of Concerns**: Extracts privilege-related functionality
from MetaCache into a dedicated `internal/proxy/privilege` package,
improving code organization and maintainability.

2. **New Package Structure**: Creates `internal/proxy/privilege/` with:
   - `cache.go`: Core privilege cache implementation (PrivilegeCache)
   - `result_cache.go`: Privilege enforcement result caching
   - `model.go`: Casbin model and policy enforcement functions
   - `meta_cache_adapter.go`: Casbin adapter for MetaCache integration
   - Corresponding test files and mock implementations

3. **MetaCache Simplification**: Removes privilege and credential
management methods from MetaCache interface and implementation:
   - Removed: GetCredentialInfo, RemoveCredential, UpdateCredential
   - Removed: GetPrivilegeInfo, GetUserRole, RefreshPolicyInfo, InitPolicyInfo
   - Deleted: meta_cache_adapter.go, privilege_cache.go and their tests

4. **Updated References**: Updates all callsites to use the new
privilegeCache global:
   - Authentication interceptor now uses privilegeCache for password
     verification
   - Credential cache operations (InvalidateCredentialCache,
     UpdateCredentialCache, UpdateCredential) now use privilegeCache
   - Policy refresh operations (RefreshPolicyInfoCache) now use privilegeCache
   - Privilege interceptor uses the new privilege.GetEnforcer() and the
     privilege result cache

5. **Improved API**: Renames cache functions for clarity:
   - GetPrivilegeCache → GetResultCache
   - SetPrivilegeCache → SetResultCache
   - CleanPrivilegeCache → CleanResultCache

This refactoring makes the codebase more modular, separates privilege
management concerns from general metadata caching, and provides a
clearer API for privilege enforcement operations.
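
For reference, a minimal sketch of the new call pattern (the result-cache
helper signatures are assumptions inferred from the names above, not the
verified API; `username` and `key` are placeholders):

    privCache := privilege.GetPrivilegeCache()   // privilege/credential cache singleton
    roles := privCache.GetUserRole(username)     // role lookup moved off MetaCache
    enforcer := privilege.GetEnforcer()          // Casbin enforcer for privilege checks
    allowed, ok := privilege.GetResultCache(key) // renamed from GetPrivilegeCache
    privilege.SetResultCache(key, allowed)       // renamed from SetPrivilegeCache
    privilege.CleanResultCache()                 // renamed from CleanPrivilegeCache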

---------

Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
2025-10-13 11:15:58 +08:00


// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"context"
"fmt"
"reflect"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.opentelemetry.io/otel"
"go.uber.org/zap"
"golang.org/x/crypto/bcrypt"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/json"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/proxy/privilege"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/analyzer"
"github.com/milvus-io/milvus/internal/util/function/embedding"
"github.com/milvus-io/milvus/internal/util/hookutil"
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
"github.com/milvus-io/milvus/internal/util/segcore"
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/contextutil"
"github.com/milvus-io/milvus/pkg/v2/util/crypto"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/metric"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
const (
strongTS = 0
boundedTS = 2
// enableMultipleVectorFields indicates whether to enable multiple vector fields.
enableMultipleVectorFields = true
defaultMaxArrayCapacity = 4096
defaultMaxSearchRequest = 1024
// DefaultArithmeticIndexType name of default index type for scalar field
DefaultArithmeticIndexType = indexparamcheck.IndexINVERTED
// DefaultStringIndexType name of default index type for varChar/string field
DefaultStringIndexType = indexparamcheck.IndexINVERTED
)
var logger = log.L().WithOptions(zap.Fields(zap.String("role", typeutil.ProxyRole)))
func ConcatStructFieldName(structName string, fieldName string) string {
return fmt.Sprintf("%s[%s]", structName, fieldName)
}
// transformStructFieldNames transforms struct field names to structName[fieldName] format
// This ensures global uniqueness while allowing same field names across different structs
func transformStructFieldNames(schema *schemapb.CollectionSchema) error {
for _, structArrayField := range schema.StructArrayFields {
structName := structArrayField.Name
for _, field := range structArrayField.Fields {
// Create transformed name: structName[fieldName]
newName := ConcatStructFieldName(structName, field.Name)
field.Name = newName
}
}
return nil
}
// restoreStructFieldNames restores original field names from structName[fieldName] format
// This is used when returning schema information to users (e.g., in describe collection)
func restoreStructFieldNames(schema *schemapb.CollectionSchema) error {
for _, structArrayField := range schema.StructArrayFields {
structName := structArrayField.Name
expectedPrefix := structName + "["
for _, field := range structArrayField.Fields {
if strings.HasPrefix(field.Name, expectedPrefix) && strings.HasSuffix(field.Name, "]") {
// Extract fieldName: remove "structName[" prefix and "]" suffix
field.Name = field.Name[len(expectedPrefix) : len(field.Name)-1]
}
}
}
return nil
}
// extractOriginalFieldName extracts the original field name from structName[fieldName] format
// This function should only be called on transformed struct field names
func extractOriginalFieldName(transformedName string) (string, error) {
idx := strings.Index(transformedName, "[")
if idx == -1 {
return "", fmt.Errorf("not a transformed struct field name: %s", transformedName)
}
if !strings.HasSuffix(transformedName, "]") {
return "", fmt.Errorf("invalid struct field format: %s, missing closing bracket", transformedName)
}
if idx == 0 {
return "", fmt.Errorf("invalid struct field format: %s, missing struct name", transformedName)
}
fieldName := transformedName[idx+1 : len(transformedName)-1]
if fieldName == "" {
return "", fmt.Errorf("invalid struct field format: %s, empty field name", transformedName)
}
return fieldName, nil
}
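// Example sketch (illustrative only): round-tripping the structName[fieldName]
// convention used by the helpers above.
func exampleStructFieldNameRoundTrip() {
	transformed := ConcatStructFieldName("points", "pos") // "points[pos]"
	original, err := extractOriginalFieldName(transformed)
	if err != nil {
		panic(err)
	}
	fmt.Println(original) // prints "pos"
}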
// isAlpha checks if c is alpha.
func isAlpha(c uint8) bool {
if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') {
return false
}
return true
}
// isNumber checks if c is a number.
func isNumber(c uint8) bool {
if c < '0' || c > '9' {
return false
}
return true
}
// validateRunAnalyzer checks run-analyzer params when the collection name is set
func validateRunAnalyzer(req *milvuspb.RunAnalyzerRequest) error {
if req.GetAnalyzerParams() != "" {
return fmt.Errorf("run analyzer can't use analyzer params and (collection,field) in same time")
}
if req.GetFieldName() == "" {
return fmt.Errorf("must set field name when collection name was set")
}
if req.GetAnalyzerNames() != nil {
if len(req.GetAnalyzerNames()) != 1 && len(req.GetAnalyzerNames()) != len(req.GetPlaceholder()) {
return fmt.Errorf("only support set one analyzer name for all text or set analyzer name for each text, but now analzer name num: %d, text num: %d",
len(req.GetAnalyzerNames()), len(req.GetPlaceholder()))
}
}
return nil
}
func validateMaxQueryResultWindow(offset int64, limit int64) error {
if offset < 0 {
return fmt.Errorf("%s [%d] is invalid, should be gte than 0", OffsetKey, offset)
}
if limit <= 0 {
return fmt.Errorf("%s [%d] is invalid, should be greater than 0", LimitKey, limit)
}
depth := offset + limit
maxQueryResultWindow := Params.QuotaConfig.MaxQueryResultWindow.GetAsInt64()
if depth <= 0 || depth > maxQueryResultWindow {
return fmt.Errorf("(offset+limit) should be in range [1, %d], but got %d", maxQueryResultWindow, depth)
}
return nil
}
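// Example: with MaxQueryResultWindow = 16384, offset=16000 and limit=500 is
// rejected (16500 > 16384), while offset=16000 and limit=384 passes.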
func validateLimit(limit int64) error {
topKLimit := Params.QuotaConfig.TopKLimit.GetAsInt64()
if limit <= 0 || limit > topKLimit {
return fmt.Errorf("it should be in range [1, %d], but got %d", topKLimit, limit)
}
return nil
}
func validateNQLimit(limit int64) error {
nqLimit := Params.QuotaConfig.NQLimit.GetAsInt64()
if limit <= 0 || limit > nqLimit {
return fmt.Errorf("nq (number of search vector per search request) should be in range [1, %d], but got %d", nqLimit, limit)
}
return nil
}
func validateCollectionNameOrAlias(entity, entityType string) error {
if entity == "" {
return merr.WrapErrParameterInvalidMsg("collection %s should not be empty", entityType)
}
invalidMsg := fmt.Sprintf("Invalid collection %s: %s. ", entityType, entity)
if len(entity) > Params.ProxyCfg.MaxNameLength.GetAsInt() {
return merr.WrapErrParameterInvalidMsg("%s the length of a collection %s must be less than %s characters", invalidMsg, entityType,
Params.ProxyCfg.MaxNameLength.GetValue())
}
firstChar := entity[0]
if firstChar != '_' && !isAlpha(firstChar) {
return merr.WrapErrParameterInvalidMsg("%s the first character of a collection %s must be an underscore or letter", invalidMsg, entityType)
}
for i := 1; i < len(entity); i++ {
c := entity[i]
if c != '_' && !isAlpha(c) && !isNumber(c) {
return merr.WrapErrParameterInvalidMsg("%s collection %s can only contain numbers, letters and underscores", invalidMsg, entityType)
}
}
return nil
}
func ValidatePrivilegeGroupName(groupName string) error {
if groupName == "" {
return merr.WrapErrPrivilegeGroupNameInvalid("privilege group name should not be empty")
}
if len(groupName) > Params.ProxyCfg.MaxNameLength.GetAsInt() {
return merr.WrapErrPrivilegeGroupNameInvalid(
"the length of a privilege group name %s must be less than %s characters", groupName, Params.ProxyCfg.MaxNameLength.GetValue())
}
firstChar := groupName[0]
if firstChar != '_' && !isAlpha(firstChar) {
return merr.WrapErrPrivilegeGroupNameInvalid(
"the first character of a privilege group name %s must be an underscore or letter", groupName)
}
for i := 1; i < len(groupName); i++ {
c := groupName[i]
if c != '_' && !isAlpha(c) && !isNumber(c) {
return merr.WrapErrParameterInvalidMsg(
"privilege group name %s can only contain numbers, letters and underscores", groupName)
}
}
return nil
}
func ValidateResourceGroupName(entity string) error {
if entity == "" {
return errors.New("resource group name couldn't be empty")
}
invalidMsg := fmt.Sprintf("Invalid resource group name %s.", entity)
if len(entity) > Params.ProxyCfg.MaxNameLength.GetAsInt() {
return merr.WrapErrParameterInvalidMsg("%s the length of a resource group name must be less than %s characters",
invalidMsg, Params.ProxyCfg.MaxNameLength.GetValue())
}
firstChar := entity[0]
if firstChar != '_' && !isAlpha(firstChar) {
return merr.WrapErrParameterInvalidMsg("%s the first character of a resource group name must be an underscore or letter", invalidMsg)
}
for i := 1; i < len(entity); i++ {
c := entity[i]
if c != '_' && !isAlpha(c) && !isNumber(c) {
return merr.WrapErrParameterInvalidMsg("%s resource group name can only contain numbers, letters and underscores", invalidMsg)
}
}
return nil
}
func ValidateDatabaseName(dbName string) error {
if dbName == "" {
return merr.WrapErrDatabaseNameInvalid(dbName, "database name couldn't be empty")
}
if len(dbName) > Params.ProxyCfg.MaxNameLength.GetAsInt() {
return merr.WrapErrDatabaseNameInvalid(dbName,
fmt.Sprintf("the length of a database name must be less than %d characters", Params.ProxyCfg.MaxNameLength.GetAsInt()))
}
firstChar := dbName[0]
if firstChar != '_' && !isAlpha(firstChar) {
return merr.WrapErrDatabaseNameInvalid(dbName,
"the first character of a database name must be an underscore or letter")
}
for i := 1; i < len(dbName); i++ {
c := dbName[i]
if c != '_' && !isAlpha(c) && !isNumber(c) {
return merr.WrapErrDatabaseNameInvalid(dbName,
"database name can only contain numbers, letters and underscores")
}
}
return nil
}
// ValidateCollectionAlias returns true if collAlias is a valid alias name for collection, otherwise returns false.
func ValidateCollectionAlias(collAlias string) error {
return validateCollectionNameOrAlias(collAlias, "alias")
}
func validateCollectionName(collName string) error {
return validateCollectionNameOrAlias(collName, "name")
}
func validatePartitionTag(partitionTag string, strictCheck bool) error {
partitionTag = strings.TrimSpace(partitionTag)
invalidMsg := "Invalid partition name: " + partitionTag + ". "
if partitionTag == "" {
msg := invalidMsg + "Partition name should not be empty."
return errors.New(msg)
}
if len(partitionTag) > Params.ProxyCfg.MaxNameLength.GetAsInt() {
msg := invalidMsg + "The length of a partition name must be less than " + Params.ProxyCfg.MaxNameLength.GetValue() + " characters."
return errors.New(msg)
}
if strictCheck {
firstChar := partitionTag[0]
if firstChar != '_' && !isAlpha(firstChar) && !isNumber(firstChar) {
msg := invalidMsg + "The first character of a partition name must be an underscore or letter."
return errors.New(msg)
}
tagSize := len(partitionTag)
for i := 1; i < tagSize; i++ {
c := partitionTag[i]
if c != '_' && !isAlpha(c) && !isNumber(c) && c != '-' {
msg := invalidMsg + "Partition name can only contain numbers, letters and underscores."
return errors.New(msg)
}
}
}
return nil
}
func validateFieldName(fieldName string) error {
fieldName = strings.TrimSpace(fieldName)
if fieldName == "" {
return merr.WrapErrFieldNameInvalid(fieldName, "field name should not be empty")
}
invalidMsg := "Invalid field name: " + fieldName + ". "
if len(fieldName) > Params.ProxyCfg.MaxNameLength.GetAsInt() {
msg := invalidMsg + "The length of a field name must be less than " + Params.ProxyCfg.MaxNameLength.GetValue() + " characters."
return merr.WrapErrFieldNameInvalid(fieldName, msg)
}
firstChar := fieldName[0]
if firstChar != '_' && !isAlpha(firstChar) {
msg := invalidMsg + "The first character of a field name must be an underscore or letter."
return merr.WrapErrFieldNameInvalid(fieldName, msg)
}
fieldNameSize := len(fieldName)
for i := 1; i < fieldNameSize; i++ {
c := fieldName[i]
if c != '_' && !isAlpha(c) && !isNumber(c) {
msg := invalidMsg + "Field name can only contain numbers, letters, and underscores."
return merr.WrapErrFieldNameInvalid(fieldName, msg)
}
}
if _, ok := common.FieldNameKeywords[fieldName]; ok {
msg := invalidMsg + fmt.Sprintf("%s is keyword in milvus.", fieldName)
return merr.WrapErrFieldNameInvalid(fieldName, msg)
}
return nil
}
func validateDimension(field *schemapb.FieldSchema) error {
exist := false
var dim int64
for _, param := range field.TypeParams {
if param.Key == common.DimKey {
exist = true
tmp, err := strconv.ParseInt(param.Value, 10, 64)
if err != nil {
return err
}
dim = tmp
break
}
}
// for sparse vector field, dim should not be specified
if typeutil.IsSparseFloatVectorType(field.DataType) {
if exist {
return fmt.Errorf("dim should not be specified for sparse vector field %s(%d)", field.GetName(), field.FieldID)
}
return nil
}
if !exist {
return errors.Newf("dimension is not defined in field type params of field %s, check type param `dim` for vector field", field.GetName())
}
if dim <= 1 {
return fmt.Errorf("invalid dimension: %d. should be in range 2 ~ %d", dim, Params.ProxyCfg.MaxDimension.GetAsInt())
}
// for dense vector field, dim will be limited by max_dimension
if typeutil.IsBinaryVectorType(field.DataType) {
if dim%8 != 0 {
return fmt.Errorf("invalid dimension: %d of field %s. binary vector dimension should be multiple of 8. ", dim, field.GetName())
}
if dim > Params.ProxyCfg.MaxDimension.GetAsInt64()*8 {
return fmt.Errorf("invalid dimension: %d of field %s. binary vector dimension should be in range 2 ~ %d", dim, field.GetName(), Params.ProxyCfg.MaxDimension.GetAsInt()*8)
}
} else {
if dim > Params.ProxyCfg.MaxDimension.GetAsInt64() {
return fmt.Errorf("invalid dimension: %d of field %s. float vector dimension should be in range 2 ~ %d", dim, field.GetName(), Params.ProxyCfg.MaxDimension.GetAsInt())
}
}
return nil
}
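// Example: assuming MaxDimension = 32768, a float vector may have dim in
// [2, 32768], while a binary vector may have dim up to 262144 (32768*8) and
// the dim must be a multiple of 8.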
func validateMaxLengthPerRow(collectionName string, field *schemapb.FieldSchema) error {
exist := false
for _, param := range field.TypeParams {
if param.Key != common.MaxLengthKey {
continue
}
maxLengthPerRow, err := strconv.ParseInt(param.Value, 10, 64)
if err != nil {
return err
}
var defaultMaxLength int64
if field.DataType == schemapb.DataType_Text {
defaultMaxLength = Params.ProxyCfg.MaxTextLength.GetAsInt64()
} else {
defaultMaxLength = Params.ProxyCfg.MaxVarCharLength.GetAsInt64()
}
if maxLengthPerRow > defaultMaxLength || maxLengthPerRow <= 0 {
return merr.WrapErrParameterInvalidMsg("the maximum length specified for the field(%s) should be in (0, %d], but got %d instead", field.GetName(), defaultMaxLength, maxLengthPerRow)
}
exist = true
}
// return an error if max_length is not present in the type params
if !exist {
return fmt.Errorf("type param(max_length) should be specified for the field(%s) of collection %s", field.GetName(), collectionName)
}
return nil
}
func validateMaxCapacityPerRow(collectionName string, field *schemapb.FieldSchema) error {
exist := false
for _, param := range field.TypeParams {
if param.Key != common.MaxCapacityKey {
continue
}
maxCapacityPerRow, err := strconv.ParseInt(param.Value, 10, 64)
if err != nil {
return fmt.Errorf("the value for %s of field %s must be an integer", common.MaxCapacityKey, field.GetName())
}
if maxCapacityPerRow > defaultMaxArrayCapacity || maxCapacityPerRow <= 0 {
return errors.New("the maximum capacity specified for a Array should be in (0, 4096]")
}
exist = true
}
// return an error if max_capacity is not present in the type params
if !exist {
return fmt.Errorf("type param(max_capacity) should be specified for array field %s of collection %s", field.GetName(), collectionName)
}
return nil
}
func validateVectorFieldMetricType(field *schemapb.FieldSchema) error {
if !typeutil.IsVectorType(field.DataType) {
return nil
}
for _, params := range field.IndexParams {
if params.Key == common.MetricTypeKey {
return nil
}
}
return fmt.Errorf(`index param "metric_type" is not specified for vector field %s`, field.GetName())
}
func validateDuplicatedFieldName(schema *schemapb.CollectionSchema) error {
names := make(map[string]bool)
validateFieldNames := func(name string) error {
_, ok := names[name]
if ok {
return errors.Newf("duplicated field name %s found", name)
}
names[name] = true
return nil
}
for _, field := range schema.Fields {
if err := validateFieldNames(field.Name); err != nil {
return err
}
}
for _, structArrayField := range schema.StructArrayFields {
if err := validateFieldNames(structArrayField.Name); err != nil {
return err
}
for _, field := range structArrayField.Fields {
if err := validateFieldNames(field.Name); err != nil {
return err
}
}
}
return nil
}
func validateElementType(dataType schemapb.DataType) error {
switch dataType {
case schemapb.DataType_Bool, schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32,
schemapb.DataType_Int64, schemapb.DataType_Float, schemapb.DataType_Double, schemapb.DataType_VarChar:
return nil
case schemapb.DataType_String:
return errors.New("string data type not supported yet, please use VarChar type instead")
case schemapb.DataType_None:
return errors.New("element data type None is not valid")
}
return fmt.Errorf("element type %s is not supported", dataType.String())
}
func validateFieldType(schema *schemapb.CollectionSchema) error {
for _, field := range schema.GetFields() {
switch field.GetDataType() {
case schemapb.DataType_String:
return errors.New("string data type not supported yet, please use VarChar type instead")
case schemapb.DataType_None:
return errors.New("data type None is not valid")
case schemapb.DataType_Array:
if err := validateElementType(field.GetElementType()); err != nil {
return err
}
}
}
for _, structArrayField := range schema.StructArrayFields {
for _, field := range structArrayField.Fields {
if field.GetDataType() != schemapb.DataType_Array && field.GetDataType() != schemapb.DataType_ArrayOfVector {
return errors.Newf("fields in StructArrayField must be Array or ArrayOfVector, field name = %s, field type = %s",
field.GetName(), field.GetDataType().String())
}
}
}
return nil
}
// ValidateFieldAutoID is called after validatePrimaryKey
func ValidateFieldAutoID(coll *schemapb.CollectionSchema) error {
idx := -1
for i, field := range coll.Fields {
if field.AutoID {
if idx != -1 {
return fmt.Errorf("only one field can speficy AutoID with true, field name = %s, %s", coll.Fields[idx].Name, field.Name)
}
idx = i
if !field.IsPrimaryKey {
return fmt.Errorf("only primary field can speficy AutoID with true, field name = %s", field.Name)
}
}
}
for _, structArrayField := range coll.StructArrayFields {
for _, field := range structArrayField.Fields {
if field.AutoID {
return errors.Newf("autoID is not supported for struct field, field name = %s", field.Name)
}
}
}
return nil
}
func ValidateField(field *schemapb.FieldSchema, schema *schemapb.CollectionSchema) error {
// validate field name
var err error
if err := validateFieldName(field.Name); err != nil {
return err
}
// validate dense vector field type parameters
isVectorType := typeutil.IsVectorType(field.DataType)
if isVectorType {
err = validateDimension(field)
if err != nil {
return err
}
}
// valid max length per row parameters
// if max_length not specified, return error
if field.DataType == schemapb.DataType_VarChar ||
(field.GetDataType() == schemapb.DataType_Array && field.GetElementType() == schemapb.DataType_VarChar) {
err = validateMaxLengthPerRow(schema.Name, field)
if err != nil {
return err
}
}
// valid max capacity for array per row parameters
// if max_capacity not specified, return error
if field.DataType == schemapb.DataType_Array {
if err = validateMaxCapacityPerRow(schema.Name, field); err != nil {
return err
}
}
if field.DataType == schemapb.DataType_ArrayOfVector {
return fmt.Errorf("array of vector can only be in the struct array field, field name: %s", field.Name)
}
// TODO should remove the index params in the field schema
indexParams := funcutil.KeyValuePair2Map(field.GetIndexParams())
if err = ValidateAutoIndexMmapConfig(isVectorType, indexParams); err != nil {
return err
}
if err := validateAnalyzer(schema, field); err != nil {
return err
}
return nil
}
func ValidateFieldsInStruct(field *schemapb.FieldSchema, schema *schemapb.CollectionSchema) error {
// validate field name
var err error
if err := validateFieldName(field.Name); err != nil {
return err
}
if field.DataType != schemapb.DataType_Array && field.DataType != schemapb.DataType_ArrayOfVector {
return fmt.Errorf("Fields in StructArrayField can only be array or array of struct, but field %s is %s", field.Name, field.DataType.String())
}
if field.ElementType == schemapb.DataType_ArrayOfStruct || field.ElementType == schemapb.DataType_ArrayOfVector ||
field.ElementType == schemapb.DataType_Array {
return fmt.Errorf("Nested array is not supported %s", field.Name)
}
if field.ElementType == schemapb.DataType_JSON {
return fmt.Errorf("JSON is not supported for fields in struct, fieldName = %s", field.Name)
}
if field.DataType == schemapb.DataType_Array {
if typeutil.IsVectorType(field.GetElementType()) {
return fmt.Errorf("Inconsistent schema: element type of array field %s is a vector type", field.Name)
}
} else {
if !typeutil.IsVectorType(field.GetElementType()) {
return fmt.Errorf("Inconsistent schema: element type of array field %s is not a vector type", field.Name)
}
err = validateDimension(field)
if err != nil {
return err
}
}
// valid max length per row parameters
// if max_length not specified, return error
if field.ElementType == schemapb.DataType_VarChar {
err = validateMaxLengthPerRow(schema.Name, field)
if err != nil {
return err
}
}
// todo(SpadeA): make nullable field in struct array supported
if field.GetNullable() {
return fmt.Errorf("nullable is not supported for fields in struct array now, fieldName = %s", field.Name)
}
return nil
}
func ValidateStructArrayField(structArrayField *schemapb.StructArrayFieldSchema, schema *schemapb.CollectionSchema) error {
if len(structArrayField.Fields) == 0 {
return fmt.Errorf("struct array field %s has no sub-fields", structArrayField.Name)
}
for _, subField := range structArrayField.Fields {
if err := ValidateFieldsInStruct(subField, schema); err != nil {
return err
}
}
return nil
}
func validateMultiAnalyzerParams(params string, coll *schemapb.CollectionSchema) error {
var m map[string]json.RawMessage
var analyzerMap map[string]json.RawMessage
var mFieldName string
err := json.Unmarshal([]byte(params), &m)
if err != nil {
return err
}
mfield, ok := m["by_field"]
if !ok {
return fmt.Errorf("multi analyzer params now must set by_field to specify with field decide analyzer")
}
err = json.Unmarshal(mfield, &mFileName)
if err != nil {
return fmt.Errorf("multi analyzer params by_field must be string but now: %s", mfield)
}
// check field exist
fieldExist := false
for _, field := range coll.GetFields() {
if field.GetName() == mFieldName {
// only support string field now
if field.GetDataType() != schemapb.DataType_VarChar {
return fmt.Errorf("multi analyzer params now only support by string field, but field %s is not string", field.GetName())
}
fieldExist = true
break
}
}
if !fieldExist {
return fmt.Errorf("multi analyzer dependent field %s not exist in collection %s", string(mfield), coll.GetName())
}
if value, ok := m["alias"]; ok {
mapping := map[string]string{}
err = json.Unmarshal(value, &mapping)
if err != nil {
return fmt.Errorf("multi analyzer alias must be string map but now: %s", value)
}
}
analyzers, ok := m["analyzers"]
if !ok {
return fmt.Errorf("multi analyzer params must set analyzers ")
}
err = json.Unmarshal(analyzers, &analyzerMap)
if err != nil {
return fmt.Errorf("unmarshal analyzers failed: %s", err)
}
hasDefault := false
for name, params := range analyzerMap {
if err := analyzer.ValidateAnalyzer(string(params)); err != nil {
return fmt.Errorf("analyzer %s params invalid: %s", name, err)
}
if name == "default" {
hasDefault = true
}
}
if !hasDefault {
return fmt.Errorf("multi analyzer must set default analyzer for all unknown value")
}
return nil
}
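// Example of params accepted above (a sketch; analyzer bodies elided):
//
//	{
//	  "by_field": "lang",
//	  "alias": {"zh-cn": "cn"},
//	  "analyzers": {"default": {...}, "cn": {...}, "en": {...}}
//	}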
func validateAnalyzer(collSchema *schemapb.CollectionSchema, fieldSchema *schemapb.FieldSchema) error {
h := typeutil.CreateFieldSchemaHelper(fieldSchema)
if !h.EnableMatch() && !wasBm25FunctionInputField(collSchema, fieldSchema) {
return nil
}
if !h.EnableAnalyzer() {
return fmt.Errorf("field %s is set to enable match or bm25 function but not enable analyzer", fieldSchema.Name)
}
if params, ok := h.GetMultiAnalyzerParams(); ok {
if h.EnableMatch() {
return fmt.Errorf("multi analyzer now only support for bm25, but now field %s enable match", fieldSchema.Name)
}
if h.HasAnalyzerParams() {
return fmt.Errorf("field %s analyzer params should be none if has multi analyzer params", fieldSchema.Name)
}
return validateMultiAnalyzerParams(params, collSchema)
}
for _, kv := range fieldSchema.GetTypeParams() {
if kv.GetKey() == "analyzer_params" {
return analyzer.ValidateAnalyzer(kv.Value)
}
}
// return nil when use default analyzer
return nil
}
func validatePrimaryKey(coll *schemapb.CollectionSchema) error {
idx := -1
for i, field := range coll.Fields {
if field.IsPrimaryKey {
if idx != -1 {
return fmt.Errorf("there are more than one primary key, field name = %s, %s", coll.Fields[idx].Name, field.Name)
}
// The type of the primary key field can only be int64 and varchar
if field.DataType != schemapb.DataType_Int64 && field.DataType != schemapb.DataType_VarChar {
return errors.New("the data type of primary key should be Int64 or VarChar")
}
// varchar field do not support autoID
// If autoID is required, it is recommended to use int64 field as the primary key
//if field.DataType == schemapb.DataType_VarChar {
// if field.AutoID {
// return errors.New("autoID is not supported when the VarChar field is the primary key")
// }
//}
idx = i
}
}
if idx == -1 {
return errors.New("primary key is not specified")
}
for _, structArrayField := range coll.StructArrayFields {
for _, field := range structArrayField.Fields {
if field.IsPrimaryKey {
return errors.Newf("primary key is not supported for struct field, field name = %s", field.Name)
}
}
}
return nil
}
func validateDynamicField(coll *schemapb.CollectionSchema) error {
for _, field := range coll.Fields {
if field.IsDynamic {
return errors.New("cannot explicitly set a field as a dynamic field")
}
}
return nil
}
// RepeatedKeyValToMap transfers the kv pairs to a map.
func RepeatedKeyValToMap(kvPairs []*commonpb.KeyValuePair) (map[string]string, error) {
resMap := make(map[string]string)
for _, kv := range kvPairs {
_, ok := resMap[kv.Key]
if ok {
return nil, fmt.Errorf("duplicated param key: %s", kv.Key)
}
resMap[kv.Key] = kv.Value
}
return resMap, nil
}
// isVector checks if dataType belongs to vector type.
func isVector(dataType schemapb.DataType) (bool, error) {
switch dataType {
case schemapb.DataType_Bool, schemapb.DataType_Int8,
schemapb.DataType_Int16, schemapb.DataType_Int32,
schemapb.DataType_Int64,
schemapb.DataType_Float, schemapb.DataType_Double:
return false, nil
case schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector, schemapb.DataType_SparseFloatVector:
return true, nil
}
return false, fmt.Errorf("invalid data type: %d", dataType)
}
func validateMetricType(dataType schemapb.DataType, metricTypeStrRaw string) error {
metricTypeStr := strings.ToUpper(metricTypeStrRaw)
switch metricTypeStr {
case metric.L2, metric.IP, metric.COSINE:
if typeutil.IsFloatVectorType(dataType) {
return nil
}
case metric.JACCARD, metric.HAMMING, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE, metric.MHJACCARD:
if dataType == schemapb.DataType_BinaryVector {
return nil
}
}
return fmt.Errorf("data_type %s mismatch with metric_type %s", dataType.String(), metricTypeStrRaw)
}
func validateFunction(coll *schemapb.CollectionSchema) error {
nameMap := lo.SliceToMap(coll.GetFields(), func(field *schemapb.FieldSchema) (string, *schemapb.FieldSchema) {
return field.GetName(), field
})
usedOutputField := typeutil.NewSet[string]()
usedFunctionName := typeutil.NewSet[string]()
// reset `IsFunctionOutput` regardless of user input; this shall be determined by the function definition only.
for _, field := range coll.Fields {
field.IsFunctionOutput = false
}
for _, function := range coll.GetFunctions() {
if err := checkFunctionBasicParams(function); err != nil {
return err
}
if usedFunctionName.Contain(function.GetName()) {
return fmt.Errorf("duplicate function name: %s", function.GetName())
}
usedFunctionName.Insert(function.GetName())
inputFields := []*schemapb.FieldSchema{}
for _, name := range function.GetInputFieldNames() {
inputField, ok := nameMap[name]
if !ok {
return fmt.Errorf("function input field not found: %s", name)
}
inputFields = append(inputFields, inputField)
}
if err := checkFunctionInputField(function, inputFields); err != nil {
return err
}
outputFields := make([]*schemapb.FieldSchema, len(function.GetOutputFieldNames()))
for i, name := range function.GetOutputFieldNames() {
outputField, ok := nameMap[name]
if !ok {
return fmt.Errorf("function output field not found: %s", name)
}
if outputField.GetIsPrimaryKey() {
return fmt.Errorf("function output field cannot be primary key: function %s, field %s", function.GetName(), outputField.GetName())
}
if outputField.GetIsPartitionKey() || outputField.GetIsClusteringKey() {
return fmt.Errorf("function output field cannot be partition key or clustering key: function %s, field %s", function.GetName(), outputField.GetName())
}
if outputField.GetNullable() {
return fmt.Errorf("function output field cannot be nullable: function %s, field %s", function.GetName(), outputField.GetName())
}
outputField.IsFunctionOutput = true
outputFields[i] = outputField
if usedOutputField.Contain(name) {
return fmt.Errorf("duplicate function output field: function %s, field %s", function.GetName(), name)
}
usedOutputField.Insert(name)
}
if err := checkFunctionOutputField(function, outputFields); err != nil {
return err
}
}
if err := embedding.ValidateFunctions(coll); err != nil {
return err
}
return nil
}
func checkFunctionOutputField(fSchema *schemapb.FunctionSchema, fields []*schemapb.FieldSchema) error {
switch fSchema.GetType() {
case schemapb.FunctionType_BM25:
if len(fields) != 1 {
return fmt.Errorf("BM25 function only need 1 output field, but got %d", len(fields))
}
if !typeutil.IsSparseFloatVectorType(fields[0].GetDataType()) {
return fmt.Errorf("BM25 function output field must be a SparseFloatVector field, but got %s", fields[0].DataType.String())
}
case schemapb.FunctionType_TextEmbedding:
if err := embedding.TextEmbeddingOutputsCheck(fields); err != nil {
return err
}
default:
return errors.New("check output field for unknown function type")
}
return nil
}
func checkFunctionInputField(function *schemapb.FunctionSchema, fields []*schemapb.FieldSchema) error {
switch function.GetType() {
case schemapb.FunctionType_BM25:
if len(fields) != 1 || (fields[0].DataType != schemapb.DataType_VarChar && fields[0].DataType != schemapb.DataType_Text) {
return fmt.Errorf("BM25 function input field must be a VARCHAR/TEXT field, got %d field with type %s",
len(fields), fields[0].DataType.String())
}
h := typeutil.CreateFieldSchemaHelper(fields[0])
if !h.EnableAnalyzer() {
return errors.New("BM25 function input field must set enable_analyzer to true")
}
case schemapb.FunctionType_TextEmbedding:
if err := embedding.TextEmbeddingInputsCheck(function.GetName(), fields); err != nil {
return err
}
default:
return errors.New("check input field with unknown function type")
}
return nil
}
func checkFunctionBasicParams(function *schemapb.FunctionSchema) error {
if function.GetName() == "" {
return errors.New("function name cannot be empty")
}
if len(function.GetInputFieldNames()) == 0 {
return fmt.Errorf("function input field names cannot be empty, function: %s", function.GetName())
}
if len(function.GetOutputFieldNames()) == 0 {
return fmt.Errorf("function output field names cannot be empty, function: %s", function.GetName())
}
for _, input := range function.GetInputFieldNames() {
if input == "" {
return fmt.Errorf("function input field name cannot be empty string, function: %s", function.GetName())
}
// if input occurs more than once, error
if lo.Count(function.GetInputFieldNames(), input) > 1 {
return fmt.Errorf("each function input field should be used exactly once in the same function, function: %s, input field: %s", function.GetName(), input)
}
}
for _, output := range function.GetOutputFieldNames() {
if output == "" {
return fmt.Errorf("function output field name cannot be empty string, function: %s", function.GetName())
}
if lo.Count(function.GetInputFieldNames(), output) > 0 {
return fmt.Errorf("a single field cannot be both input and output in the same function, function: %s, field: %s", function.GetName(), output)
}
if lo.Count(function.GetOutputFieldNames(), output) > 1 {
return fmt.Errorf("each function output field should be used exactly once in the same function, function: %s, output field: %s", function.GetName(), output)
}
}
switch function.GetType() {
case schemapb.FunctionType_BM25:
if len(function.GetParams()) != 0 {
return errors.New("BM25 function accepts no params")
}
case schemapb.FunctionType_TextEmbedding:
if len(function.GetParams()) == 0 {
return errors.New("TextEmbedding function accepts no params")
}
default:
return errors.New("check function params with unknown function type")
}
return nil
}
// validateMultipleVectorFields checks if the schema has multiple vector fields.
func validateMultipleVectorFields(schema *schemapb.CollectionSchema) error {
vecExist := false
var vecName string
for i := range schema.Fields {
name := schema.Fields[i].Name
dType := schema.Fields[i].DataType
isVec := typeutil.IsVectorType(dType)
if isVec && vecExist && !enableMultipleVectorFields {
return fmt.Errorf(
"multiple vector fields is not supported, fields name: %s, %s",
vecName,
name,
)
} else if isVec {
vecExist = true
vecName = name
}
}
// todo(Spadea): should there be any check between vectors in struct fields?
return nil
}
func validateLoadFieldsList(schema *schemapb.CollectionSchema) error {
var vectorCnt int
for _, field := range schema.Fields {
shouldLoad, err := common.ShouldFieldBeLoaded(field.GetTypeParams())
if err != nil {
return err
}
// should load field, skip other checks
if shouldLoad {
if typeutil.IsVectorType(field.GetDataType()) {
vectorCnt++
}
continue
}
if field.IsPrimaryKey {
return merr.WrapErrParameterInvalidMsg("Primary key field %s cannot skip loading", field.GetName())
}
if field.IsPartitionKey {
return merr.WrapErrParameterInvalidMsg("Partition Key field %s cannot skip loading", field.GetName())
}
if field.IsClusteringKey {
return merr.WrapErrParameterInvalidMsg("Clustering Key field %s cannot skip loading", field.GetName())
}
}
for _, structArrayField := range schema.StructArrayFields {
for _, field := range structArrayField.Fields {
shouldLoad, err := common.ShouldFieldBeLoaded(field.GetTypeParams())
if err != nil {
return err
}
if shouldLoad {
if typeutil.IsVectorType(field.ElementType) {
vectorCnt++
}
continue
}
}
}
if vectorCnt == 0 {
return merr.WrapErrParameterInvalidMsg("cannot config all vector field(s) skip loading")
}
return nil
}
// parsePrimaryFieldData2IDs gets IDs to fill the gRPC result, for example for insert and delete requests.
func parsePrimaryFieldData2IDs(fieldData *schemapb.FieldData) (*schemapb.IDs, error) {
primaryData := &schemapb.IDs{}
switch fieldData.Field.(type) {
case *schemapb.FieldData_Scalars:
scalarField := fieldData.GetScalars()
switch scalarField.Data.(type) {
case *schemapb.ScalarField_LongData:
primaryData.IdField = &schemapb.IDs_IntId{
IntId: scalarField.GetLongData(),
}
case *schemapb.ScalarField_StringData:
primaryData.IdField = &schemapb.IDs_StrId{
StrId: scalarField.GetStringData(),
}
default:
return nil, merr.WrapErrParameterInvalidMsg("currently only support DataType Int64 or VarChar as PrimaryField")
}
default:
return nil, merr.WrapErrParameterInvalidMsg("currently not support vector field as PrimaryField")
}
return primaryData, nil
}
// autoGenPrimaryFieldData generates primary data when autoID == true
func autoGenPrimaryFieldData(fieldSchema *schemapb.FieldSchema, data interface{}) (*schemapb.FieldData, error) {
var fieldData schemapb.FieldData
fieldData.FieldName = fieldSchema.Name
fieldData.Type = fieldSchema.DataType
switch data := data.(type) {
case []int64:
switch fieldData.Type {
case schemapb.DataType_Int64:
fieldData.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: data,
},
},
},
}
case schemapb.DataType_VarChar:
strIDs := make([]string, len(data))
for i, v := range data {
strIDs[i] = strconv.FormatInt(v, 10)
}
fieldData.Field = &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: strIDs,
},
},
},
}
default:
return nil, errors.New("currently only support autoID for int64 and varchar PrimaryField")
}
default:
return nil, errors.New("currently only int64 is supported as the data source for the autoID of a PrimaryField")
}
return &fieldData, nil
}
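// Example: for a VarChar primary key with autoID enabled, generated IDs
// []int64{100, 101} are converted to the string IDs ["100", "101"].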
func autoGenDynamicFieldData(data [][]byte) *schemapb.FieldData {
return &schemapb.FieldData{
FieldName: common.MetaFieldName,
Type: schemapb.DataType_JSON,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_JsonData{
JsonData: &schemapb.JSONArray{
Data: data,
},
},
},
},
IsDynamic: true,
}
}
// fillFieldPropertiesBySchema sets fieldID on fieldData according to the FieldSchemas
func fillFieldPropertiesBySchema(columns []*schemapb.FieldData, schema *schemapb.CollectionSchema) error {
fieldName2Schema := make(map[string]*schemapb.FieldSchema)
expectColumnNum := 0
for _, field := range schema.GetFields() {
fieldName2Schema[field.Name] = field
if !IsBM25FunctionOutputField(field, schema) {
expectColumnNum++
}
}
for _, structField := range schema.GetStructArrayFields() {
for _, field := range structField.GetFields() {
fieldName2Schema[field.Name] = field
expectColumnNum++
}
}
if len(columns) != expectColumnNum {
return fmt.Errorf("len(columns) mismatch the expectColumnNum, expectColumnNum: %d, len(columns): %d",
expectColumnNum, len(columns))
}
for _, fieldData := range columns {
if fieldSchema, ok := fieldName2Schema[fieldData.FieldName]; ok {
fieldData.FieldId = fieldSchema.FieldID
fieldData.Type = fieldSchema.DataType
// Set the ElementType because it may not be set in the insert request.
if fieldData.Type == schemapb.DataType_Array {
fd, ok := fieldData.Field.(*schemapb.FieldData_Scalars)
if !ok || fd.Scalars.GetArrayData() == nil {
return fmt.Errorf("field convert FieldData_Scalars fail in fieldData, fieldName: %s,"+
" collectionName:%s", fieldData.FieldName, schema.Name)
}
fd.Scalars.GetArrayData().ElementType = fieldSchema.ElementType
} else if fieldData.Type == schemapb.DataType_ArrayOfVector {
fd, ok := fieldData.Field.(*schemapb.FieldData_Vectors)
if !ok || fd.Vectors.GetVectorArray() == nil {
return fmt.Errorf("field convert FieldData_Vectors fail in fieldData, fieldName: %s,"+
" collectionName:%s", fieldData.FieldName, schema.Name)
}
fd.Vectors.GetVectorArray().ElementType = fieldSchema.ElementType
}
} else {
return fmt.Errorf("fieldName %v not exist in collection schema", fieldData.FieldName)
}
}
return nil
}
func ValidateUsername(username string) error {
username = strings.TrimSpace(username)
if username == "" {
return merr.WrapErrParameterInvalidMsg("username must be not empty")
}
if len(username) > Params.ProxyCfg.MaxUsernameLength.GetAsInt() {
return merr.WrapErrParameterInvalidMsg("invalid username %s with length %d, the length of username must be less than %d", username, len(username), Params.ProxyCfg.MaxUsernameLength.GetValue())
}
firstChar := username[0]
if !isAlpha(firstChar) {
return merr.WrapErrParameterInvalidMsg("invalid user name %s, the first character must be a letter, but got %s", username, string(firstChar))
}
usernameSize := len(username)
for i := 1; i < usernameSize; i++ {
c := username[i]
if c != '_' && c != '-' && c != '.' && !isAlpha(c) && !isNumber(c) {
return merr.WrapErrParameterInvalidMsg("invalid user name %s, username must contain only numbers, letters, underscores, dots, and hyphens, but got %s", username, c)
}
}
return nil
}
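// Examples: "user_1" and "a.b-c" are valid; "1user" is rejected because the
// first character must be a letter, and "" is rejected as empty.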
func ValidatePassword(password string) error {
if len(password) < Params.ProxyCfg.MinPasswordLength.GetAsInt() || len(password) > Params.ProxyCfg.MaxPasswordLength.GetAsInt() {
return merr.WrapErrParameterInvalidRange(Params.ProxyCfg.MinPasswordLength.GetAsInt(),
Params.ProxyCfg.MaxPasswordLength.GetAsInt(),
len(password), "invalid password length")
}
return nil
}
func ReplaceID2Name(oldStr string, id int64, name string) string {
return strings.ReplaceAll(oldStr, strconv.FormatInt(id, 10), name)
}
func parseGuaranteeTsFromConsistency(ts, tMax typeutil.Timestamp, consistency commonpb.ConsistencyLevel) typeutil.Timestamp {
switch consistency {
case commonpb.ConsistencyLevel_Strong:
ts = tMax
case commonpb.ConsistencyLevel_Bounded:
ratio := Params.CommonCfg.GracefulTime.GetAsDuration(time.Millisecond)
ts = tsoutil.AddPhysicalDurationOnTs(tMax, -ratio)
case commonpb.ConsistencyLevel_Eventually:
ts = 1
}
return ts
}
func parseGuaranteeTs(ts, tMax typeutil.Timestamp) typeutil.Timestamp {
switch ts {
case strongTS:
ts = tMax
case boundedTS:
ratio := Params.CommonCfg.GracefulTime.GetAsDuration(time.Millisecond)
ts = tsoutil.AddPhysicalDurationOnTs(tMax, -ratio)
}
return ts
}
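// Example: a request with guarantee timestamp strongTS (0) is upgraded to
// tMax (strong consistency), while boundedTS (2) subtracts the configured
// graceful time from tMax, tolerating slightly stale reads.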
func getMaxMvccTsFromChannels(channelsTs map[string]uint64, beginTs typeutil.Timestamp) typeutil.Timestamp {
maxTs := typeutil.Timestamp(0)
for _, ts := range channelsTs {
if ts > maxTs {
maxTs = ts
}
}
if maxTs == 0 {
log.Warn("no channel ts found, use beginTs instead")
return beginTs
}
return maxTs
}
func validateName(entity string, nameType string) error {
return validateNameWithCustomChars(entity, nameType, Params.ProxyCfg.NameValidationAllowedChars.GetValue())
}
func validateNameWithCustomChars(entity string, nameType string, allowedChars string) error {
entity = strings.TrimSpace(entity)
if entity == "" {
return merr.WrapErrParameterInvalid("not empty", entity, nameType+" should be not empty")
}
if len(entity) > Params.ProxyCfg.MaxNameLength.GetAsInt() {
return merr.WrapErrParameterInvalidRange(0,
Params.ProxyCfg.MaxNameLength.GetAsInt(),
len(entity),
fmt.Sprintf("the length of %s must be not greater than limit", nameType))
}
firstChar := entity[0]
if firstChar != '_' && !isAlpha(firstChar) {
return merr.WrapErrParameterInvalid('_',
firstChar,
fmt.Sprintf("the first character of %s must be an underscore or letter", nameType))
}
for i := 1; i < len(entity); i++ {
c := entity[i]
if c != '_' && !isAlpha(c) && !isNumber(c) && !strings.ContainsRune(allowedChars, rune(c)) {
return merr.WrapErrParameterInvalidMsg("%s can only contain numbers, letters, underscores, and allowed characters (%s), found %c at %d", nameType, allowedChars, c, i)
}
}
return nil
}
func ValidateRoleName(entity string) error {
return validateNameWithCustomChars(entity, "role name", Params.ProxyCfg.RoleNameValidationAllowedChars.GetValue())
}
func IsDefaultRole(roleName string) bool {
for _, defaultRole := range util.DefaultRoles {
if defaultRole == roleName {
return true
}
}
return false
}
func ValidateObjectName(entity string) error {
if util.IsAnyWord(entity) {
return nil
}
return validateName(entity, "object name")
}
func ValidateCollectionName(entity string) error {
if util.IsAnyWord(entity) {
return nil
}
return validateName(entity, "collection name")
}
func ValidateObjectType(entity string) error {
return validateName(entity, "ObjectType")
}
func ValidatePrivilege(entity string) error {
if util.IsAnyWord(entity) {
return nil
}
return validateName(entity, "Privilege")
}
func GetCurUserFromContext(ctx context.Context) (string, error) {
return contextutil.GetCurUserFromContext(ctx)
}
func GetCurUserFromContextOrDefault(ctx context.Context) string {
username, _ := GetCurUserFromContext(ctx)
return username
}
func GetCurDBNameFromContextOrDefault(ctx context.Context) string {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return util.DefaultDBName
}
dbNameData := md[strings.ToLower(util.HeaderDBName)]
if len(dbNameData) < 1 || dbNameData[0] == "" {
return util.DefaultDBName
}
return dbNameData[0]
}
func NewContextWithMetadata(ctx context.Context, username string, dbName string) context.Context {
dbKey := strings.ToLower(util.HeaderDBName)
if dbName != "" {
ctx = contextutil.AppendToIncomingContext(ctx, dbKey, dbName)
}
if username != "" {
originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username)
authKey := strings.ToLower(util.HeaderAuthorize)
authValue := crypto.Base64Encode(originValue)
ctx = contextutil.AppendToIncomingContext(ctx, authKey, authValue)
}
return ctx
}
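// Example sketch: NewContextWithMetadata(ctx, "root", "db1") attaches db name
// "db1" and an authorization value of base64("root" + CredentialSeperator +
// "root") under the lowercase authorization header key.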
func AppendUserInfoForRPC(ctx context.Context) context.Context {
curUser, _ := GetCurUserFromContext(ctx)
if curUser != "" {
originValue := fmt.Sprintf("%s%s%s", curUser, util.CredentialSeperator, curUser)
authKey := strings.ToLower(util.HeaderAuthorize)
authValue := crypto.Base64Encode(originValue)
ctx = metadata.AppendToOutgoingContext(ctx, authKey, authValue)
}
return ctx
}
func GetRole(username string) ([]string, error) {
privCache := privilege.GetPrivilegeCache()
if privCache == nil {
return []string{}, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait")
}
return privCache.GetUserRole(username), nil
}
func PasswordVerify(ctx context.Context, username, rawPwd string) bool {
return passwordVerify(ctx, username, rawPwd, privilege.GetPrivilegeCache())
}
func VerifyAPIKey(rawToken string) (string, error) {
hoo := hookutil.GetHook()
user, err := hoo.VerifyAPIKey(rawToken)
if err != nil {
log.Warn("fail to verify apikey", zap.String("api_key", rawToken), zap.Error(err))
return "", merr.WrapErrParameterInvalidMsg("invalid apikey: [%s]", rawToken)
}
return user, nil
}
// passwordVerify verifies the raw password against the cached credential info
func passwordVerify(ctx context.Context, username, rawPwd string, privilegeCache privilege.PrivilegeCache) bool {
// an empty Sha256Password within credInfo indicates a cache miss; it is filled in on the first connection,
// since generating Sha256Password requires the raw password, and the bcrypt-encrypted password is not cached in that form.
credInfo, err := privilegeCache.GetCredentialInfo(ctx, username)
if err != nil {
log.Ctx(ctx).Error("found no credential", zap.String("username", username), zap.Error(err))
return false
}
// hit cache
sha256Pwd := crypto.SHA256(rawPwd, credInfo.Username)
if credInfo.Sha256Password != "" {
return sha256Pwd == credInfo.Sha256Password
}
// miss cache, verify against encrypted password from etcd
if err := bcrypt.CompareHashAndPassword([]byte(credInfo.EncryptedPassword), []byte(rawPwd)); err != nil {
log.Ctx(ctx).Error("Verify password failed", zap.Error(err))
return false
}
// update cache after miss cache
credInfo.Sha256Password = sha256Pwd
log.Ctx(ctx).Debug("get credential miss cache, update cache with", zap.Any("credential", credInfo))
privilegeCache.UpdateCredential(credInfo)
return true
}
func translatePkOutputFields(schema *schemapb.CollectionSchema) ([]string, []int64) {
pkNames := []string{}
fieldIDs := []int64{}
for _, field := range schema.Fields {
if field.IsPrimaryKey {
pkNames = append(pkNames, field.GetName())
fieldIDs = append(fieldIDs, field.GetFieldID())
}
}
return pkNames, fieldIDs
}
func recallCal[T string | int64](results []T, gts []T) float32 {
hit := 0
total := 0
for _, r := range results {
total++
for _, gt := range gts {
if r == gt {
hit++
break
}
}
}
return float32(hit) / float32(total)
}
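// Example: recallCal([]int64{1, 2, 3}, []int64{2, 3, 4}) returns 2/3: results
// 2 and 3 appear in the ground truth while 1 does not.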
func computeRecall(results *schemapb.SearchResultData, gts *schemapb.SearchResultData) error {
if results.GetNumQueries() != gts.GetNumQueries() {
return fmt.Errorf("num of queries is inconsistent between search results(%d) and ground truth(%d)", results.GetNumQueries(), gts.GetNumQueries())
}
switch results.GetIds().GetIdField().(type) {
case *schemapb.IDs_IntId:
switch gts.GetIds().GetIdField().(type) {
case *schemapb.IDs_IntId:
currentResultIndex := int64(0)
currentGTIndex := int64(0)
recalls := make([]float32, 0, results.GetNumQueries())
for i := 0; i < int(results.GetNumQueries()); i++ {
currentResultTopk := results.GetTopks()[i]
currentGTTopk := gts.GetTopks()[i]
recalls = append(recalls, recallCal(results.GetIds().GetIntId().GetData()[currentResultIndex:currentResultIndex+currentResultTopk],
gts.GetIds().GetIntId().GetData()[currentGTIndex:currentGTIndex+currentGTTopk]))
currentResultIndex += currentResultTopk
currentGTIndex += currentGTTopk
}
results.Recalls = recalls
return nil
case *schemapb.IDs_StrId:
return errors.New("pk type is inconsistent between search results(int64) and ground truth(string)")
default:
return errors.New("unsupported pk type")
}
case *schemapb.IDs_StrId:
switch gts.GetIds().GetIdField().(type) {
case *schemapb.IDs_StrId:
currentResultIndex := int64(0)
currentGTIndex := int64(0)
recalls := make([]float32, 0, results.GetNumQueries())
for i := 0; i < int(results.GetNumQueries()); i++ {
currentResultTopk := results.GetTopks()[i]
currentGTTopk := gts.GetTopks()[i]
recalls = append(recalls, recallCal(results.GetIds().GetStrId().GetData()[currentResultIndex:currentResultIndex+currentResultTopk],
gts.GetIds().GetStrId().GetData()[currentGTIndex:currentGTIndex+currentGTTopk]))
currentResultIndex += currentResultTopk
currentGTIndex += currentGTTopk
}
results.Recalls = recalls
return nil
case *schemapb.IDs_IntId:
return errors.New("pk type is inconsistent between search results(string) and ground truth(int64)")
default:
return errors.New("unsupported pk type")
}
default:
return errors.New("unsupported pk type")
}
}
// Support wildcard in output fields:
//
// "*" - all fields
//
// For example, A and B are scalar fields, C and D are vector fields, duplicated fields will automatically be removed.
//
// output_fields=["*"] ==> [A,B,C,D]
// output_fields=["*",A] ==> [A,B,C,D]
// output_fields=["*",C] ==> [A,B,C,D]
//
// 4th return value is true if user requested pk field explicitly or using wildcard.
// if removePkField is true, pk field will not be included in the first (resultFieldNames) / second (userOutputFields)
// return value.
func translateOutputFields(outputFields []string, schema *schemaInfo, removePkField bool) ([]string, []string, []string, bool, error) {
var primaryFieldName string
allFieldNameMap := make(map[string]*schemapb.FieldSchema)
resultFieldNameMap := make(map[string]bool)
resultFieldNames := make([]string, 0)
userOutputFieldsMap := make(map[string]bool)
userOutputFields := make([]string, 0)
userDynamicFieldsMap := make(map[string]bool)
userDynamicFields := make([]string, 0)
useAllDynamicFields := false
for _, field := range schema.Fields {
if field.IsPrimaryKey {
primaryFieldName = field.Name
}
allFieldNameMap[field.Name] = field
}
// User may specify a struct array field or some specific fields in the struct array field
for _, subStruct := range schema.StructArrayFields {
for _, field := range subStruct.Fields {
allFieldNameMap[field.Name] = field
}
}
structArrayNameToFields := make(map[string][]*schemapb.FieldSchema)
for _, subStruct := range schema.StructArrayFields {
structArrayNameToFields[subStruct.Name] = subStruct.Fields
}
userRequestedPkFieldExplicitly := false
for _, outputFieldName := range outputFields {
outputFieldName = strings.TrimSpace(outputFieldName)
if outputFieldName == primaryFieldName {
userRequestedPkFieldExplicitly = true
}
if outputFieldName == "*" {
userRequestedPkFieldExplicitly = true
for fieldName, field := range allFieldNameMap {
if schema.CanRetrieveRawFieldData(field) {
resultFieldNameMap[fieldName] = true
userOutputFieldsMap[fieldName] = true
}
}
useAllDynamicFields = true
} else {
if structArrayField, ok := structArrayNameToFields[outputFieldName]; ok {
for _, field := range structArrayField {
if schema.CanRetrieveRawFieldData(field) {
resultFieldNameMap[field.Name] = true
userOutputFieldsMap[field.Name] = true
}
}
continue
}
if field, ok := allFieldNameMap[outputFieldName]; ok {
if !schema.CanRetrieveRawFieldData(field) {
return nil, nil, nil, false, fmt.Errorf("not allowed to retrieve raw data of field %s", outputFieldName)
}
resultFieldNameMap[outputFieldName] = true
userOutputFieldsMap[outputFieldName] = true
} else {
if schema.EnableDynamicField {
dynamicNestedPath := outputFieldName
err := planparserv2.ParseIdentifier(schema.schemaHelper, outputFieldName, func(expr *planpb.Expr) error {
columnInfo := expr.GetColumnExpr().GetInfo()
// there must be no error here
dynamicField, _ := schema.schemaHelper.GetDynamicField()
// only $meta["xxx"] is allowed for now
if dynamicField.GetFieldID() != columnInfo.GetFieldId() {
return errors.New("not support getting subkeys of json field yet")
}
nestedPaths := columnInfo.GetNestedPath()
// $meta["A"]["B"] not allowed for now
if len(nestedPaths) != 1 {
return errors.New("not support getting multiple level of dynamic field for now")
}
// $meta["dyn_field"], output field name could be:
// 1. "dyn_field", outputFieldName == nestedPath
// 2. `$meta["dyn_field"]` explicit form
if nestedPaths[0] != outputFieldName {
// use "dyn_field" as userDynamicFieldsMap when outputField = `$meta["dyn_field"]`
dynamicNestedPath = nestedPaths[0]
}
return nil
})
if err != nil {
log.Info("parse output field name failed", zap.String("field name", outputFieldName), zap.Error(err))
return nil, nil, nil, false, fmt.Errorf("parse output field name failed: %s", outputFieldName)
}
resultFieldNameMap[common.MetaFieldName] = true
userOutputFieldsMap[outputFieldName] = true
userDynamicFieldsMap[dynamicNestedPath] = true
} else {
return nil, nil, nil, false, fmt.Errorf("field %s not exist", outputFieldName)
}
}
}
}
if removePkField {
delete(resultFieldNameMap, primaryFieldName)
delete(userOutputFieldsMap, primaryFieldName)
}
for fieldName := range resultFieldNameMap {
resultFieldNames = append(resultFieldNames, fieldName)
}
for fieldName := range userOutputFieldsMap {
userOutputFields = append(userOutputFields, fieldName)
}
if !useAllDynamicFields {
for fieldName := range userDynamicFieldsMap {
userDynamicFields = append(userDynamicFields, fieldName)
}
}
return resultFieldNames, userOutputFields, userDynamicFields, userRequestedPkFieldExplicitly, nil
}
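// Usage sketch for translateOutputFields (illustrative only; "id" and "color"
// are hypothetical field names, and schema is a *schemaInfo with the dynamic
// field enabled):
//
//	resultFields, userFields, dynFields, pkExplicit, err := translateOutputFields(
//		[]string{"*", `$meta["color"]`}, schema, false)
//	// resultFields contains every retrievable field plus "$meta"; dynFields
//	// stays empty because "*" selects all dynamic fields; pkExplicit is true
//	// since "*" implicitly requests the primary key "id".
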
func validateIndexName(indexName string) error {
indexName = strings.TrimSpace(indexName)
if indexName == "" {
return nil
}
invalidMsg := "Invalid index name: " + indexName + ". "
if len(indexName) > Params.ProxyCfg.MaxNameLength.GetAsInt() {
msg := invalidMsg + "The length of an index name must not exceed " + Params.ProxyCfg.MaxNameLength.GetValue() + " characters."
return errors.New(msg)
}
firstChar := indexName[0]
if firstChar != '_' && !isAlpha(firstChar) {
msg := invalidMsg + "The first character of an index name must be an underscore or letter."
return errors.New(msg)
}
indexNameSize := len(indexName)
for i := 1; i < indexNameSize; i++ {
c := indexName[i]
if c != '_' && !isAlpha(c) && !isNumber(c) {
msg := invalidMsg + "Index name can only contain numbers, letters, and underscores."
return errors.New(msg)
}
}
return nil
}
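// Illustrative checks against the rules above (names are hypothetical):
//
//	validateIndexName("")          // nil: empty names pass validation
//	validateIndexName("_idx_2024") // nil: underscore prefix, letters, and digits are valid
//	validateIndexName("2024_idx")  // error: first character must be a letter or underscore
//	validateIndexName("idx-2024")  // error: '-' is neither a letter, digit, nor underscore
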
func isCollectionLoaded(ctx context.Context, mc types.MixCoordClient, collID int64) (bool, error) {
// get all loaded collections
resp, err := mc.ShowLoadCollections(ctx, &querypb.ShowCollectionsRequest{
CollectionIDs: nil,
})
if err != nil {
return false, err
}
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
return false, merr.Error(resp.GetStatus())
}
for _, loadedCollID := range resp.GetCollectionIDs() {
if collID == loadedCollID {
return true, nil
}
}
return false, nil
}
func isPartitionLoaded(ctx context.Context, mc types.MixCoordClient, collID int64, partID int64) (bool, error) {
// check whether the given partition is loaded
resp, err := mc.ShowLoadPartitions(ctx, &querypb.ShowPartitionsRequest{
CollectionID: collID,
PartitionIDs: []int64{partID},
})
if err := merr.CheckRPCCall(resp, err); err != nil {
// qc returns error if partition not loaded
if errors.Is(err, merr.ErrPartitionNotLoaded) {
return false, nil
}
return false, err
}
return true, nil
}
func checkFieldsDataBySchema(allFields []*schemapb.FieldSchema, schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg, inInsert bool) error {
log := log.With(zap.String("collection", schema.GetName()))
primaryKeyNum := 0
autoGenFieldNum := 0
dataNameSet := typeutil.NewSet[string]()
for _, data := range insertMsg.FieldsData {
fieldName := data.GetFieldName()
if dataNameSet.Contain(fieldName) {
return merr.WrapErrParameterInvalidMsg("duplicated field %s found", fieldName)
}
dataNameSet.Insert(fieldName)
}
allowInsertAutoID, _ := common.IsAllowInsertAutoID(schema.GetProperties()...)
hasPkData := false
needAutoGenPk := false
for _, fieldSchema := range allFields {
if fieldSchema.AutoID && !fieldSchema.IsPrimaryKey {
log.Warn("not primary key field, but set autoID true", zap.String("field", fieldSchema.GetName()))
return merr.WrapErrParameterInvalidMsg("only the primary key field can have AutoID enabled")
}
if fieldSchema.IsPrimaryKey {
primaryKeyNum++
hasPkData = dataNameSet.Contain(fieldSchema.GetName())
needAutoGenPk = fieldSchema.AutoID && (!allowInsertAutoID || !hasPkData)
}
if fieldSchema.GetDefaultValue() != nil && fieldSchema.IsPrimaryKey {
return merr.WrapErrParameterInvalidMsg("primary key field cannot have a default value")
}
if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && needAutoGenPk && inInsert) || IsBM25FunctionOutputField(fieldSchema, schema) {
// in insert, the pk field does not need to be passed in when it is autoID and SkipAutoIDCheck is false
autoGenFieldNum++
}
if _, ok := dataNameSet[fieldSchema.GetName()]; !ok {
if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && needAutoGenPk && inInsert) || IsBM25FunctionOutputField(fieldSchema, schema) {
// autoGenField
continue
}
if fieldSchema.GetDefaultValue() == nil && !fieldSchema.GetNullable() {
log.Warn("no corresponding fieldData passed in", zap.String("fieldSchema", fieldSchema.GetName()))
return merr.WrapErrParameterInvalidMsg("fieldSchema(%s) has no corresponding fieldData passed in", fieldSchema.GetName())
}
// when use default_value or has set Nullable
// it's ok that no corresponding fieldData found
dataToAppend, err := typeutil.GenEmptyFieldData(fieldSchema)
if err != nil {
return err
}
dataToAppend.ValidData = make([]bool, insertMsg.GetNumRows())
insertMsg.FieldsData = append(insertMsg.FieldsData, dataToAppend)
}
}
if primaryKeyNum > 1 {
log.Warn("more than one primary key is not supported",
zap.Int64("primaryKeyNum", int64(primaryKeyNum)))
return merr.WrapErrParameterInvalidMsg("more than one primary key is not supported, got %d", primaryKeyNum)
}
expectedNum := len(allFields)
actualNum := len(insertMsg.FieldsData) + autoGenFieldNum
if expectedNum != actualNum {
log.Warn("the number of fields is not the same as needed", zap.Int("expected", expectedNum), zap.Int("actual", actualNum))
return merr.WrapErrParameterInvalid(expectedNum, actualNum, "more fieldData has been passed in than expected")
}
return nil
}
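// Accounting sketch for checkFieldsDataBySchema (hypothetical schema with an
// AutoID int64 primary key "id" and a vector field "vec", inInsert == true,
// SkipAutoIDCheck and allow_insert_auto_id both disabled):
//
//	// The caller passes only "vec"; "id" counts as auto-generated, so
//	//   expectedNum = 2 (id, vec)
//	//   actualNum   = 1 (vec) + 1 (autoGenFieldNum for id) = 2
//	// and the field-count check passes without explicit pk data.
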
// checkAndFlattenStructFieldData verifies the array length of the struct array field data in the insert message
// and then flattens the data so that the data node and query node do not have to handle struct array field data.
func checkAndFlattenStructFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) error {
structSchemaMap := make(map[string]*schemapb.StructArrayFieldSchema, len(schema.GetStructArrayFields()))
for _, structField := range schema.GetStructArrayFields() {
structSchemaMap[structField.Name] = structField
}
fieldSchemaMap := make(map[string]*schemapb.FieldSchema, len(schema.GetFields()))
for _, fieldSchema := range schema.GetFields() {
fieldSchemaMap[fieldSchema.Name] = fieldSchema
}
structFieldCount := 0
flattenedFields := make([]*schemapb.FieldData, 0, len(insertMsg.GetFieldsData())+5)
for _, fieldData := range insertMsg.GetFieldsData() {
if _, ok := fieldSchemaMap[fieldData.FieldName]; ok {
flattenedFields = append(flattenedFields, fieldData)
continue
}
structName := fieldData.FieldName
structSchema, ok := structSchemaMap[structName]
if !ok {
return fmt.Errorf("fieldName %v does not exist in collection schema, fieldType %v, fieldId %v", fieldData.FieldName, fieldData.Type, fieldData.FieldId)
}
structFieldCount++
structArrays, ok := fieldData.Field.(*schemapb.FieldData_StructArrays)
if !ok {
return fmt.Errorf("failed to convert field to FieldData_StructArrays, fieldName: %s,"+
" collectionName:%s", structName, schema.Name)
}
if len(structArrays.StructArrays.Fields) != len(structSchema.GetFields()) {
return fmt.Errorf("the number of fields in the struct field data does not match the schema, fieldName: %s,"+
" collectionName:%s, fieldData fields length:%d, schema fields length:%d",
structName, schema.Name, len(structArrays.StructArrays.Fields), len(structSchema.GetFields()))
}
// Check the array length of the struct array field data
expectedArrayLen := -1
for _, subField := range structArrays.StructArrays.Fields {
var currentArrayLen int
switch subFieldData := subField.Field.(type) {
case *schemapb.FieldData_Scalars:
if scalarArray := subFieldData.Scalars.GetArrayData(); scalarArray != nil {
currentArrayLen = len(scalarArray.Data)
} else {
return fmt.Errorf("scalar array data is nil in struct field '%s', sub-field '%s'",
structName, subField.FieldName)
}
case *schemapb.FieldData_Vectors:
if vectorArray := subFieldData.Vectors.GetVectorArray(); vectorArray != nil {
currentArrayLen = len(vectorArray.Data)
} else {
return fmt.Errorf("vector array data is nil in struct field '%s', sub-field '%s'",
structName, subField.FieldName)
}
default:
return fmt.Errorf("unexpected field data type in struct array field, fieldName: %s", structName)
}
if expectedArrayLen == -1 {
expectedArrayLen = currentArrayLen
} else if currentArrayLen != expectedArrayLen {
return fmt.Errorf("inconsistent array length in struct field '%s': expected %d, got %d for sub-field '%s'",
structName, expectedArrayLen, currentArrayLen, subField.FieldName)
}
transformedFieldName := ConcatStructFieldName(structName, subField.FieldName)
subFieldCopy := &schemapb.FieldData{
FieldName: transformedFieldName,
FieldId: subField.FieldId,
Type: subField.Type,
Field: subField.Field,
IsDynamic: subField.IsDynamic,
}
flattenedFields = append(flattenedFields, subFieldCopy)
}
}
if len(schema.GetStructArrayFields()) != structFieldCount {
return fmt.Errorf("the number of struct array fields is not the same as needed, expected: %d, actual: %d",
len(schema.GetStructArrayFields()), structFieldCount)
}
insertMsg.FieldsData = flattenedFields
return nil
}
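// Flattening sketch (hypothetical struct array field "segments" with
// sub-fields "text" and "vec"): the single FieldData named "segments" in the
// insert message is replaced by two plain entries named via
//
//	ConcatStructFieldName("segments", "text")
//	ConcatStructFieldName("segments", "vec")
//
// so that data nodes and query nodes only ever see flat fields; the original
// grouping is restored on the way out by reconstructStructFieldDataCommon.
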
func checkPrimaryFieldData(allFields []*schemapb.FieldSchema, schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) (*schemapb.IDs, error) {
log := log.With(zap.String("collectionName", insertMsg.CollectionName))
rowNums := uint32(insertMsg.NRows())
// TODO(dragondriver): in fact, NumRows is not trustworthy, we should check all input fields
if insertMsg.NRows() <= 0 {
return nil, merr.WrapErrParameterInvalid("invalid num_rows", fmt.Sprint(rowNums), "num_rows should be greater than 0")
}
if err := checkFieldsDataBySchema(allFields, schema, insertMsg, true); err != nil {
return nil, err
}
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema)
if err != nil {
log.Error("get primary field schema failed", zap.Any("schema", schema), zap.Error(err))
return nil, err
}
if primaryFieldSchema.GetNullable() {
return nil, merr.WrapErrParameterInvalidMsg("primary field does not support null")
}
var primaryFieldData *schemapb.FieldData
// when checkPrimaryFieldData in insert
allowInsertAutoID, _ := common.IsAllowInsertAutoID(schema.GetProperties()...)
skipAutoIDCheck := primaryFieldSchema.AutoID &&
typeutil.IsPrimaryFieldDataExist(insertMsg.GetFieldsData(), primaryFieldSchema) && (Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() || allowInsertAutoID)
if !primaryFieldSchema.AutoID || skipAutoIDCheck {
primaryFieldData, err = typeutil.GetPrimaryFieldData(insertMsg.GetFieldsData(), primaryFieldSchema)
if err != nil {
log.Info("get primary field data failed", zap.Error(err))
return nil, err
}
} else {
// check primary key data not exist
if typeutil.IsPrimaryFieldDataExist(insertMsg.GetFieldsData(), primaryFieldSchema) {
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("cannot assign primary field data when autoID is enabled and allow_insert_auto_id is false, primary field: %v", primaryFieldSchema.Name))
}
// if autoID == true, currently support autoID for int64 and varchar PrimaryField
primaryFieldData, err = autoGenPrimaryFieldData(primaryFieldSchema, insertMsg.GetRowIDs())
if err != nil {
log.Info("generate primary field data failed when autoID == true", zap.Error(err))
return nil, err
}
// if autoID == true, set the primary field data
// insertMsg.fieldsData need append primaryFieldData
insertMsg.FieldsData = append(insertMsg.FieldsData, primaryFieldData)
}
// parse primaryFieldData to result.IDs, and as returned primary keys
ids, err := parsePrimaryFieldData2IDs(primaryFieldData)
if err != nil {
log.Warn("parse primary field data to IDs failed", zap.Error(err))
return nil, err
}
return ids, nil
}
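// Behavior sketch for checkPrimaryFieldData (hypothetical int64 pk "id" with
// AutoID enabled, SkipAutoIDCheck and allow_insert_auto_id both disabled):
//
//	ids, err := checkPrimaryFieldData(allFields, schema, insertMsg)
//	// - passing "id" data  -> error: pk must not be assigned while AutoID is on
//	// - omitting "id" data -> pk values are generated from insertMsg.RowIDs,
//	//                         appended to FieldsData, and returned in ids
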
// LackOfFieldsDataBySchema checks whether fieldsData covers all fields required by the schema
func LackOfFieldsDataBySchema(schema *schemapb.CollectionSchema, fieldsData []*schemapb.FieldData, skipPkFieldCheck bool, skipDynamicFieldCheck bool) error {
log := log.With(zap.String("collection", schema.GetName()))
// find bm25 generated fields
bm25Fields := typeutil.NewSet[string](GetFunctionOutputFields(schema)...)
dataNameMap := make(map[string]*schemapb.FieldData)
for _, data := range fieldsData {
dataNameMap[data.GetFieldName()] = data
}
for _, fieldSchema := range schema.Fields {
if bm25Fields.Contain(fieldSchema.GetName()) {
continue
}
if fieldSchema.GetNullable() || fieldSchema.GetDefaultValue() != nil {
continue
}
if _, ok := dataNameMap[fieldSchema.GetName()]; !ok {
if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && skipPkFieldCheck) ||
IsBM25FunctionOutputField(fieldSchema, schema) ||
(skipDynamicFieldCheck && fieldSchema.GetIsDynamic()) {
// autoGenField
continue
}
log.Info("no corresponding fieldData passed in", zap.String("fieldSchema", fieldSchema.GetName()))
return merr.WrapErrParameterInvalidMsg("fieldSchema(%s) has no corresponding fieldData passed in", fieldSchema.GetName())
}
}
return nil
}
// For varchar fields, and text fields with an analyzer enabled,
// we need to check the character encoding before inserting the data into the message queue.
// Currently only UTF-8 is supported.
func checkInputUtf8Compatiable(allFields []*schemapb.FieldSchema, insertMsg *msgstream.InsertMsg) error {
checkedFields := lo.FilterMap(allFields, func(field *schemapb.FieldSchema, _ int) (int64, bool) {
if field.DataType == schemapb.DataType_VarChar {
return field.GetFieldID(), true
}
if field.DataType != schemapb.DataType_Text {
return 0, false
}
for _, kv := range field.GetTypeParams() {
if kv.Key == common.EnableAnalyzerKey {
return field.GetFieldID(), true
}
}
return 0, false
})
if len(checkedFields) == 0 {
return nil
}
for _, fieldData := range insertMsg.FieldsData {
if !lo.Contains(checkedFields, fieldData.GetFieldId()) {
continue
}
strData := fieldData.GetScalars().GetStringData()
for row, data := range strData.GetData() {
ok := utf8.ValidString(data)
if !ok {
log.Warn("string field data not utf-8 format", zap.String("messageVersion", strData.ProtoReflect().Descriptor().Syntax().GoString()))
return merr.WrapErrAsInputError(fmt.Errorf("input with analyzer should be utf-8 format, but row: %d not utf-8 format. data: %s", row, data))
}
}
}
return nil
}
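// Validation sketch: only VarChar fields, plus Text fields with the analyzer
// enabled, are scanned, and every row must satisfy utf8.ValidString, e.g.
// (hypothetical data):
//
//	utf8.ValidString("café")               // true  -> row accepted
//	utf8.ValidString(string([]byte{0xff})) // false -> row rejected with an input error
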
func checkUpsertPrimaryFieldData(allFields []*schemapb.FieldSchema, schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) (*schemapb.IDs, *schemapb.IDs, error) {
log := log.With(zap.String("collectionName", insertMsg.CollectionName))
rowNums := uint32(insertMsg.NRows())
// TODO(dragondriver): in fact, NumRows is not trustworthy, we should check all input fields
if insertMsg.NRows() <= 0 {
return nil, nil, merr.WrapErrParameterInvalid("invalid num_rows", fmt.Sprint(rowNums), "num_rows should be greater than 0")
}
if err := checkFieldsDataBySchema(allFields, schema, insertMsg, false); err != nil {
return nil, nil, err
}
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema)
if err != nil {
log.Error("get primary field schema failed", zap.Any("schema", schema), zap.Error(err))
return nil, nil, err
}
if primaryFieldSchema.GetNullable() {
return nil, nil, merr.WrapErrParameterInvalidMsg("primary field does not support null")
}
// get primaryFieldData whether autoID is true or not
var primaryFieldData *schemapb.FieldData
var newPrimaryFieldData *schemapb.FieldData
primaryFieldID := primaryFieldSchema.FieldID
primaryFieldName := primaryFieldSchema.Name
for i, field := range insertMsg.GetFieldsData() {
if field.FieldId == primaryFieldID || field.FieldName == primaryFieldName {
primaryFieldData = field
if primaryFieldSchema.AutoID {
// use the passed pk as the new pk when autoID == false
// automatically generate a new pk when autoID == true
newPrimaryFieldData, err = autoGenPrimaryFieldData(primaryFieldSchema, insertMsg.GetRowIDs())
if err != nil {
log.Info("generate new primary field data failed when upsert", zap.Error(err))
return nil, nil, err
}
insertMsg.FieldsData = append(insertMsg.GetFieldsData()[:i], insertMsg.GetFieldsData()[i+1:]...)
insertMsg.FieldsData = append(insertMsg.FieldsData, newPrimaryFieldData)
}
break
}
}
// must assign primary field data when upsert
if primaryFieldData == nil {
return nil, nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("must assign pk when upsert, primary field: %v", primaryFieldName))
}
// parse primaryFieldData to result.IDs, and as returned primary keys
ids, err := parsePrimaryFieldData2IDs(primaryFieldData)
if err != nil {
log.Warn("parse primary field data to IDs failed", zap.Error(err))
return nil, nil, err
}
if !primaryFieldSchema.GetAutoID() {
return ids, ids, nil
}
newIDs, err := parsePrimaryFieldData2IDs(newPrimaryFieldData)
if err != nil {
log.Warn("parse primary field data to IDs failed", zap.Error(err))
return nil, nil, err
}
return newIDs, ids, nil
}
func getPartitionKeyFieldData(fieldSchema *schemapb.FieldSchema, insertMsg *msgstream.InsertMsg) (*schemapb.FieldData, error) {
if len(insertMsg.GetPartitionName()) > 0 && !Params.ProxyCfg.SkipPartitionKeyCheck.GetAsBool() {
return nil, errors.New("manually specifying partition names is not supported when partition key mode is used")
}
for _, fieldData := range insertMsg.GetFieldsData() {
if fieldData.GetFieldId() == fieldSchema.GetFieldID() {
return fieldData, nil
}
}
return nil, errors.New("partition key not specified when inserting")
}
func getCollectionProgress(
ctx context.Context,
queryCoord types.QueryCoordClient,
msgBase *commonpb.MsgBase,
collectionID int64,
) (loadProgress int64, refreshProgress int64, err error) {
resp, err := queryCoord.ShowLoadCollections(ctx, &querypb.ShowCollectionsRequest{
Base: commonpbutil.UpdateMsgBase(
msgBase,
commonpbutil.WithMsgType(commonpb.MsgType_ShowCollections),
),
CollectionIDs: []int64{collectionID},
})
if err != nil {
log.Ctx(ctx).Warn("fail to show collections",
zap.Int64("collectionID", collectionID),
zap.Error(err),
)
return
}
err = merr.Error(resp.GetStatus())
if err != nil {
log.Ctx(ctx).Warn("fail to show collections",
zap.Int64("collectionID", collectionID),
zap.Error(err))
return
}
loadProgress = resp.GetInMemoryPercentages()[0]
if len(resp.GetRefreshProgress()) > 0 { // Compatibility for new Proxy with old QueryCoord
refreshProgress = resp.GetRefreshProgress()[0]
}
return
}
func getPartitionProgress(
ctx context.Context,
queryCoord types.QueryCoordClient,
msgBase *commonpb.MsgBase,
partitionNames []string,
collectionName string,
collectionID int64,
dbName string,
) (loadProgress int64, refreshProgress int64, err error) {
IDs2Names := make(map[int64]string)
partitionIDs := make([]int64, 0)
for _, partitionName := range partitionNames {
var partitionID int64
partitionID, err = globalMetaCache.GetPartitionID(ctx, dbName, collectionName, partitionName)
if err != nil {
return
}
IDs2Names[partitionID] = partitionName
partitionIDs = append(partitionIDs, partitionID)
}
var resp *querypb.ShowPartitionsResponse
resp, err = queryCoord.ShowLoadPartitions(ctx, &querypb.ShowPartitionsRequest{
Base: commonpbutil.UpdateMsgBase(
msgBase,
commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions),
),
CollectionID: collectionID,
PartitionIDs: partitionIDs,
})
if err != nil {
log.Ctx(ctx).Warn("fail to show partitions", zap.Int64("collection_id", collectionID),
zap.String("collection_name", collectionName),
zap.Strings("partition_names", partitionNames),
zap.Error(err))
return
}
err = merr.Error(resp.GetStatus())
if err != nil {
log.Ctx(ctx).Warn("fail to show partitions",
zap.String("collectionName", collectionName),
zap.Strings("partitionNames", partitionNames),
zap.Error(err))
return
}
for _, p := range resp.InMemoryPercentages {
loadProgress += p
}
loadProgress /= int64(len(partitionIDs))
if len(resp.GetRefreshProgress()) > 0 { // Compatibility for new Proxy with old QueryCoord
refreshProgress = resp.GetRefreshProgress()[0]
}
return
}
func isPartitionKeyMode(ctx context.Context, dbName string, colName string) (bool, error) {
colSchema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, colName)
if err != nil {
return false, err
}
for _, fieldSchema := range colSchema.GetFields() {
if fieldSchema.IsPartitionKey {
return true, nil
}
}
return false, nil
}
func hasPartitionKeyModeField(schema *schemapb.CollectionSchema) bool {
for _, fieldSchema := range schema.GetFields() {
if fieldSchema.IsPartitionKey {
return true
}
}
return false
}
// getDefaultPartitionsInPartitionKeyMode is only used in partition key mode
func getDefaultPartitionsInPartitionKeyMode(ctx context.Context, dbName string, collectionName string) ([]string, error) {
partitions, err := globalMetaCache.GetPartitions(ctx, dbName, collectionName)
if err != nil {
return nil, err
}
// Make sure the order of the partition names got every time is the same
partitionNames, _, err := typeutil.RearrangePartitionsForPartitionKey(partitions)
if err != nil {
return nil, err
}
return partitionNames, nil
}
func assignChannelsByPK(pks *schemapb.IDs, channelNames []string, insertMsg *msgstream.InsertMsg) map[string][]int {
insertMsg.HashValues = typeutil.HashPK2Channels(pks, channelNames)
// insertMsg.HashValues holds, for each row offset, the index of the dmChannel it routes to
channel2RowOffsets := make(map[string][]int) // channelName -> row offsets
// assert len(insertMsg.HashValues) < maxInt
for offset, channelID := range insertMsg.HashValues {
channelName := channelNames[channelID]
if _, ok := channel2RowOffsets[channelName]; !ok {
channel2RowOffsets[channelName] = []int{}
}
channel2RowOffsets[channelName] = append(channel2RowOffsets[channelName], offset)
}
return channel2RowOffsets
}
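// Routing sketch (hypothetical pks and channels): HashPK2Channels assigns each
// row a channel index, and rows hashing to the same channel are grouped by
// their offsets, e.g. for 4 rows and two channels:
//
//	offsets := assignChannelsByPK(pks, []string{"ch0", "ch1"}, insertMsg)
//	// offsets might be map[string][]int{"ch0": {0, 2}, "ch1": {1, 3}}
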
func assignPartitionKeys(ctx context.Context, dbName string, collName string, keys []*planpb.GenericValue) ([]string, error) {
partitionNames, err := globalMetaCache.GetPartitionsIndex(ctx, dbName, collName)
if err != nil {
return nil, err
}
schema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, collName)
if err != nil {
return nil, err
}
partitionKeyFieldSchema, err := typeutil.GetPartitionKeyFieldSchema(schema.CollectionSchema)
if err != nil {
return nil, err
}
hashedPartitionNames, err := typeutil2.HashKey2Partitions(partitionKeyFieldSchema, keys, partitionNames)
return hashedPartitionNames, err
}
func ErrWithLog(logger *log.MLogger, msg string, err error) error {
wrapErr := errors.Wrap(err, msg)
if logger != nil {
logger.Warn(msg, zap.Error(err))
return wrapErr
}
log.Warn(msg, zap.Error(err))
return wrapErr
}
func verifyDynamicFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) error {
for _, field := range insertMsg.FieldsData {
if field.GetFieldName() == common.MetaFieldName {
if !schema.EnableDynamicField {
return fmt.Errorf("without dynamic schema enabled, the field name cannot be set to %s", common.MetaFieldName)
}
for _, rowData := range field.GetScalars().GetJsonData().GetData() {
jsonData := make(map[string]interface{})
if err := json.Unmarshal(rowData, &jsonData); err != nil {
log.Info("insert invalid dynamic data, milvus only supports json map",
zap.ByteString("data", rowData),
zap.Error(err),
)
return merr.WrapErrIoFailedReason(err.Error())
}
if _, ok := jsonData[common.MetaFieldName]; ok {
return fmt.Errorf("cannot set json key to: %s", common.MetaFieldName)
}
for _, f := range schema.GetFields() {
if _, ok := jsonData[f.GetName()]; ok {
log.Info("dynamic field name include the static field name", zap.String("fieldName", f.GetName()))
return fmt.Errorf("dynamic field name cannot include the static field name: %s", f.GetName())
}
}
}
}
}
return nil
}
func checkDynamicFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) error {
for _, data := range insertMsg.FieldsData {
if data.IsDynamic {
data.FieldName = common.MetaFieldName
return verifyDynamicFieldData(schema, insertMsg)
}
}
defaultData := make([][]byte, insertMsg.NRows())
for i := range defaultData {
defaultData[i] = []byte("{}")
}
dynamicData := autoGenDynamicFieldData(defaultData)
insertMsg.FieldsData = append(insertMsg.FieldsData, dynamicData)
return nil
}
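// Defaulting sketch: when no incoming FieldData is marked IsDynamic, every row
// receives an empty JSON object for the dynamic column, equivalent to
//
//	autoGenDynamicFieldData([][]byte{[]byte("{}"), []byte("{}") /* one per row */})
//
// so downstream consumers can rely on the $meta column always being present.
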
func addNamespaceData(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) error {
namespaceEnabled, _, err := common.ParseNamespaceProp(schema.Properties...)
if err != nil {
return err
}
namespaceIsSet := insertMsg.InsertRequest.Namespace != nil
if namespaceEnabled != namespaceIsSet {
if namespaceIsSet {
return fmt.Errorf("namespace data is set but the namespace property is disabled")
}
return fmt.Errorf("namespace data is not set but the namespace property is enabled")
}
if !namespaceEnabled {
return nil
}
// check namespace field exists
namespaceField := typeutil.GetFieldByName(schema, common.NamespaceFieldName)
if namespaceField == nil {
return fmt.Errorf("namespace field not found")
}
// check namespace field data is already set
for _, fieldData := range insertMsg.FieldsData {
if fieldData.FieldId == namespaceField.FieldID {
return fmt.Errorf("namespace field data is already set by users")
}
}
// set namespace field data
namespaceData := make([]string, insertMsg.NRows())
namespace := *insertMsg.InsertRequest.Namespace
for i := range namespaceData {
namespaceData[i] = namespace
}
insertMsg.FieldsData = append(insertMsg.FieldsData, &schemapb.FieldData{
FieldName: namespaceField.Name,
FieldId: namespaceField.FieldID,
Type: namespaceField.DataType,
Field: &schemapb.FieldData_Scalars{
Scalars: &schemapb.ScalarField{
Data: &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: namespaceData,
},
},
},
},
})
return nil
}
func GetCachedCollectionSchema(ctx context.Context, dbName string, colName string) (*schemaInfo, error) {
if globalMetaCache != nil {
return globalMetaCache.GetCollectionSchema(ctx, dbName, colName)
}
return nil, merr.WrapErrServiceNotReady(paramtable.GetRole(), paramtable.GetNodeID(), "initialization")
}
func CheckDatabase(ctx context.Context, dbName string) bool {
if globalMetaCache != nil {
return globalMetaCache.HasDatabase(ctx, dbName)
}
return false
}
func SetReportValue(status *commonpb.Status, value int) {
if value <= 0 {
return
}
if !merr.Ok(status) {
return
}
if status.ExtraInfo == nil {
status.ExtraInfo = make(map[string]string)
}
status.ExtraInfo["report_value"] = strconv.Itoa(value)
}
func SetStorageCost(status *commonpb.Status, storageCost segcore.StorageCost) {
if !Params.QueryNodeCfg.StorageUsageTrackingEnabled.GetAsBool() {
return
}
if storageCost.ScannedTotalBytes <= 0 {
return
}
if !merr.Ok(status) {
return
}
if status.ExtraInfo == nil {
status.ExtraInfo = make(map[string]string)
// set report_value to 0 for compatibility: when ExtraInfo is not nil, the default report_value must always be present,
// see https://github.com/milvus-io/pymilvus/pull/2999; pymilvus used the value without checking whether report_value was set
status.ExtraInfo["report_value"] = strconv.Itoa(0)
}
status.ExtraInfo["scanned_remote_bytes"] = strconv.FormatInt(storageCost.ScannedRemoteBytes, 10)
status.ExtraInfo["scanned_total_bytes"] = strconv.FormatInt(storageCost.ScannedTotalBytes, 10)
cacheHitRatio := float64(storageCost.ScannedTotalBytes-storageCost.ScannedRemoteBytes) / float64(storageCost.ScannedTotalBytes)
status.ExtraInfo["cache_hit_ratio"] = strconv.FormatFloat(cacheHitRatio, 'f', -1, 64)
}
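// Ratio sketch: with ScannedTotalBytes = 100 and ScannedRemoteBytes = 25,
// cache_hit_ratio = (100 - 25) / 100 = 0.75, i.e. 75% of the scanned bytes
// were served from local cache rather than remote storage.
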
func GetCostValue(status *commonpb.Status) int {
if status == nil || status.ExtraInfo == nil {
return 0
}
value, err := strconv.Atoi(status.ExtraInfo["report_value"])
if err != nil {
return 0
}
return value
}
// the final return value indicates whether the parsed values are valid
func GetStorageCost(status *commonpb.Status) (int64, int64, float64, bool) {
if status == nil || status.ExtraInfo == nil {
return 0, 0, 0, false
}
var scannedRemoteBytes int64
var scannedTotalBytes int64
var cacheHitRatio float64
var err error
if value, ok := status.ExtraInfo["scanned_remote_bytes"]; ok {
scannedRemoteBytes, err = strconv.ParseInt(value, 10, 64)
if err != nil {
log.Warn("scanned_remote_bytes is not a valid int64", zap.String("value", value), zap.Error(err))
return 0, 0, 0, false
}
} else {
return 0, 0, 0, false
}
if value, ok := status.ExtraInfo["scanned_total_bytes"]; ok {
scannedTotalBytes, err = strconv.ParseInt(value, 10, 64)
if err != nil {
log.Warn("scanned_total_bytes is not a valid int64", zap.String("value", value), zap.Error(err))
return 0, 0, 0, false
}
} else {
return 0, 0, 0, false
}
if value, ok := status.ExtraInfo["cache_hit_ratio"]; ok {
cacheHitRatio, err = strconv.ParseFloat(value, 64)
if err != nil {
log.Warn("cache_hit_ratio is not a valid float64", zap.String("value", value), zap.Error(err))
return 0, 0, 0, false
}
} else {
return 0, 0, 0, false
}
return scannedRemoteBytes, scannedTotalBytes, cacheHitRatio, true
}
// GetRequestInfo returns the database ID, the collection-to-partition ID mapping, the rate type of the request, and the number of tokens needed.
func GetRequestInfo(ctx context.Context, req proto.Message) (int64, map[int64][]int64, internalpb.RateType, int, error) {
switch r := req.(type) {
case *milvuspb.InsertRequest:
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
return dbID, collToPartIDs, internalpb.RateType_DMLInsert, proto.Size(r), err
case *milvuspb.UpsertRequest:
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
return dbID, collToPartIDs, internalpb.RateType_DMLInsert, proto.Size(r), err
case *milvuspb.DeleteRequest:
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
return dbID, collToPartIDs, internalpb.RateType_DMLDelete, proto.Size(r), err
case *milvuspb.ImportRequest:
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
return dbID, collToPartIDs, internalpb.RateType_DMLBulkLoad, proto.Size(r), err
case *milvuspb.SearchRequest:
dbID, collToPartIDs, err := getCollectionAndPartitionIDs(ctx, req.(reqPartNames))
return dbID, collToPartIDs, internalpb.RateType_DQLSearch, int(r.GetNq()), err
case *milvuspb.QueryRequest:
dbID, collToPartIDs, err := getCollectionAndPartitionIDs(ctx, req.(reqPartNames))
return dbID, collToPartIDs, internalpb.RateType_DQLQuery, 1, err // think of the query request's nq as 1
case *milvuspb.CreateCollectionRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.DropCollectionRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.LoadCollectionRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.ReleaseCollectionRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
case *milvuspb.CreatePartitionRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.DropPartitionRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.LoadPartitionsRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.ReleasePartitionsRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
case *milvuspb.CreateIndexRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLIndex, 1, nil
case *milvuspb.DropIndexRequest:
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
return dbID, collToPartIDs, internalpb.RateType_DDLIndex, 1, nil
case *milvuspb.FlushRequest:
db, err := globalMetaCache.GetDatabaseInfo(ctx, r.GetDbName())
if err != nil {
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
}
collToPartIDs := make(map[int64][]int64, 0)
for _, collectionName := range r.GetCollectionNames() {
collectionID, err := globalMetaCache.GetCollectionID(ctx, r.GetDbName(), collectionName)
if err != nil {
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
}
collToPartIDs[collectionID] = []int64{}
}
return db.dbID, collToPartIDs, internalpb.RateType_DDLFlush, 1, nil
case *milvuspb.ManualCompactionRequest:
dbName := GetCurDBNameFromContextOrDefault(ctx)
dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, dbName)
if err != nil {
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
}
return dbInfo.dbID, map[int64][]int64{
r.GetCollectionID(): {},
}, internalpb.RateType_DDLCompaction, 1, nil
case *milvuspb.CreateDatabaseRequest:
log.Info("rate limiter CreateDatabaseRequest")
return util.InvalidDBID, map[int64][]int64{}, internalpb.RateType_DDLDB, 1, nil
case *milvuspb.DropDatabaseRequest:
log.Info("rate limiter DropDatabaseRequest")
return util.InvalidDBID, map[int64][]int64{}, internalpb.RateType_DDLDB, 1, nil
case *milvuspb.AlterDatabaseRequest:
return util.InvalidDBID, map[int64][]int64{}, internalpb.RateType_DDLDB, 1, nil
default: // TODO: support more request
if req == nil {
return util.InvalidDBID, map[int64][]int64{}, 0, 0, errors.New("null request")
}
log.RatedWarn(60, "not supported request type for rate limiter", zap.String("type", reflect.TypeOf(req).String()))
return util.InvalidDBID, map[int64][]int64{}, 0, 0, nil
}
}
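// Dispatch sketch (hypothetical requests): DML requests are charged by message
// size while searches are charged by nq, e.g.
//
//	_, _, rt, n, _ := GetRequestInfo(ctx, &milvuspb.SearchRequest{Nq: 10})
//	// rt == internalpb.RateType_DQLSearch, n == 10
//	_, _, rt, n, _ = GetRequestInfo(ctx, insertReq)
//	// rt == internalpb.RateType_DMLInsert, n == proto.Size(insertReq)
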
// GetFailedResponse returns failed response.
func GetFailedResponse(req any, err error) any {
switch req.(type) {
case *milvuspb.InsertRequest, *milvuspb.DeleteRequest, *milvuspb.UpsertRequest:
return failedMutationResult(err)
case *milvuspb.ImportRequest:
return &milvuspb.ImportResponse{
Status: merr.Status(err),
}
case *milvuspb.SearchRequest:
return &milvuspb.SearchResults{
Status: merr.Status(err),
}
case *milvuspb.QueryRequest:
return &milvuspb.QueryResults{
Status: merr.Status(err),
}
case *milvuspb.CreateCollectionRequest, *milvuspb.DropCollectionRequest,
*milvuspb.LoadCollectionRequest, *milvuspb.ReleaseCollectionRequest,
*milvuspb.CreatePartitionRequest, *milvuspb.DropPartitionRequest,
*milvuspb.LoadPartitionsRequest, *milvuspb.ReleasePartitionsRequest,
*milvuspb.CreateIndexRequest, *milvuspb.DropIndexRequest,
*milvuspb.CreateDatabaseRequest, *milvuspb.DropDatabaseRequest,
*milvuspb.AlterDatabaseRequest:
return merr.Status(err)
case *milvuspb.FlushRequest:
return &milvuspb.FlushResponse{
Status: merr.Status(err),
}
case *milvuspb.ManualCompactionRequest:
return &milvuspb.ManualCompactionResponse{
Status: merr.Status(err),
}
}
return nil
}
func GetReplicateID(ctx context.Context, database, collectionName string) (string, error) {
if globalMetaCache == nil {
return "", merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait")
}
colInfo, err := globalMetaCache.GetCollectionInfo(ctx, database, collectionName, 0)
if err != nil {
return "", err
}
if colInfo.replicateID != "" {
return colInfo.replicateID, nil
}
dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, database)
if err != nil {
return "", err
}
replicateID, _ := common.GetReplicateID(dbInfo.properties)
return replicateID, nil
}
func IsBM25FunctionOutputField(field *schemapb.FieldSchema, collSchema *schemapb.CollectionSchema) bool {
if !(field.GetIsFunctionOutput() && field.GetDataType() == schemapb.DataType_SparseFloatVector) {
return false
}
for _, fSchema := range collSchema.Functions {
if fSchema.Type == schemapb.FunctionType_BM25 {
if len(fSchema.OutputFieldNames) != 0 && field.Name == fSchema.OutputFieldNames[0] {
return true
}
if len(fSchema.OutputFieldIds) != 0 && field.FieldID == fSchema.OutputFieldIds[0] {
return true
}
}
}
return false
}
func GetFunctionOutputFields(collSchema *schemapb.CollectionSchema) []string {
fields := make([]string, 0)
for _, fSchema := range collSchema.Functions {
fields = append(fields, fSchema.OutputFieldNames...)
}
return fields
}
func GetBM25FunctionOutputFields(collSchema *schemapb.CollectionSchema) []string {
fields := make([]string, 0)
for _, fSchema := range collSchema.Functions {
if fSchema.Type == schemapb.FunctionType_BM25 {
fields = append(fields, fSchema.OutputFieldNames...)
}
}
return fields
}
func getCollectionTTL(pairs []*commonpb.KeyValuePair) uint64 {
properties := make(map[string]string)
for _, pair := range pairs {
properties[pair.Key] = pair.Value
}
v, ok := properties[common.CollectionTTLConfigKey]
if ok {
ttl, err := strconv.Atoi(v)
if err != nil {
return 0
}
return uint64(time.Duration(ttl) * time.Second)
}
return 0
}
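// Unit note (sketch): the TTL property is configured in seconds, but the value
// returned above is the nanosecond count of the corresponding time.Duration,
// e.g. a "3600" property yields uint64(time.Duration(3600)*time.Second), i.e.
// 3_600_000_000_000 nanoseconds.
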
// reconstructStructFieldDataCommon reconstructs struct fields from flattened sub-fields
// It works with both QueryResults and SearchResults by operating on the common data structures
func reconstructStructFieldDataCommon(
fieldsData []*schemapb.FieldData,
outputFields []string,
schema *schemapb.CollectionSchema,
) ([]*schemapb.FieldData, []string) {
if len(outputFields) == 1 && outputFields[0] == "count(*)" {
return fieldsData, outputFields
}
if len(schema.StructArrayFields) == 0 {
return fieldsData, outputFields
}
regularFieldIDs := make(map[int64]interface{})
subFieldToStructMap := make(map[int64]int64)
groupedStructFields := make(map[int64][]*schemapb.FieldData)
structFieldNames := make(map[int64]string)
reconstructedOutputFields := make([]string, 0, len(fieldsData))
// record all regular field IDs
for _, field := range schema.Fields {
regularFieldIDs[field.GetFieldID()] = nil
}
// build the mapping from sub-field ID to struct field ID
for _, structField := range schema.StructArrayFields {
for _, subField := range structField.GetFields() {
subFieldToStructMap[subField.GetFieldID()] = structField.GetFieldID()
}
structFieldNames[structField.GetFieldID()] = structField.GetName()
}
newFieldsData := make([]*schemapb.FieldData, 0, len(fieldsData))
for _, field := range fieldsData {
fieldID := field.GetFieldId()
if _, ok := regularFieldIDs[fieldID]; ok {
newFieldsData = append(newFieldsData, field)
reconstructedOutputFields = append(reconstructedOutputFields, field.GetFieldName())
} else {
structFieldID := subFieldToStructMap[fieldID]
groupedStructFields[structFieldID] = append(groupedStructFields[structFieldID], field)
}
}
for structFieldID, fields := range groupedStructFields {
// Create deep copies of fields to avoid modifying original data
// and restore original field names for user-facing response
copiedFields := make([]*schemapb.FieldData, len(fields))
for i, field := range fields {
copiedFields[i] = proto.Clone(field).(*schemapb.FieldData)
// Extract original field name from structName[fieldName] format
originalName, err := extractOriginalFieldName(copiedFields[i].FieldName)
if err != nil {
// This should not happen in normal operation - indicates a bug
log.Error("failed to extract original field name from struct field",
zap.String("fieldName", copiedFields[i].FieldName),
zap.Error(err))
// Keep the transformed name to avoid data corruption
} else {
copiedFields[i].FieldName = originalName
}
}
fieldData := &schemapb.FieldData{
FieldName: structFieldNames[structFieldID],
FieldId: structFieldID,
Type: schemapb.DataType_ArrayOfStruct,
Field: &schemapb.FieldData_StructArrays{StructArrays: &schemapb.StructArrayField{Fields: copiedFields}},
}
newFieldsData = append(newFieldsData, fieldData)
reconstructedOutputFields = append(reconstructedOutputFields, structFieldNames[structFieldID])
}
return newFieldsData, reconstructedOutputFields
}
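// Reconstruction sketch (hypothetical struct field "segments" with sub-fields
// "text" and "vec"): the flattened columns produced at insert time are
// regrouped for the user-facing response,
//
//	["id", "segments[text]", "segments[vec]"] -> ["id", "segments"]
//
// where "segments" is a FieldData_StructArrays holding deep copies of the two
// sub-columns with their original names restored by extractOriginalFieldName.
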
// Wrapper for QueryResults
func reconstructStructFieldDataForQuery(results *milvuspb.QueryResults, schema *schemapb.CollectionSchema) {
fieldsData, outputFields := reconstructStructFieldDataCommon(
results.FieldsData,
results.OutputFields,
schema,
)
results.FieldsData = fieldsData
results.OutputFields = outputFields
}
// Wrapper for SearchResults
func reconstructStructFieldDataForSearch(results *milvuspb.SearchResults, schema *schemapb.CollectionSchema) {
if results.Results == nil {
return
}
fieldsData, outputFields := reconstructStructFieldDataCommon(
results.Results.FieldsData,
results.Results.OutputFields,
schema,
)
results.Results.FieldsData = fieldsData
results.Results.OutputFields = outputFields
}
func hasTimestamptzField(schema *schemapb.CollectionSchema) bool {
for _, field := range schema.Fields {
if field.GetDataType() == schemapb.DataType_Timestamptz {
return true
}
}
return false
}
func getDefaultTimezoneVal(props ...*commonpb.KeyValuePair) (bool, string) {
for _, p := range props {
// used in collection or database
if p.GetKey() == common.DatabaseDefaultTimezone || p.GetKey() == common.CollectionDefaultTimezone {
return true, p.Value
}
}
return false, ""
}
func checkTimezone(props ...*commonpb.KeyValuePair) error {
hasTimezone, timezoneStr := getDefaultTimezoneVal(props...)
if hasTimezone {
_, err := time.LoadLocation(timezoneStr)
if err != nil {
return merr.WrapErrParameterInvalidMsg("invalid timezone, should be an IANA timezone name: %s", err.Error())
}
}
return nil
}
func getColTimezone(colInfo *collectionInfo) (bool, string) {
return getDefaultTimezoneVal(colInfo.properties...)
}
func getDbTimezone(dbInfo *databaseInfo) (bool, string) {
return getDefaultTimezoneVal(dbInfo.properties...)
}
func timestamptzIsoStr2Utc(columns []*schemapb.FieldData, colTimezone string) error {
naiveLayouts := []string{
"2006-01-02T15:04:05.999999999",
"2006-01-02T15:04:05",
"2006-01-02 15:04:05.999999999",
"2006-01-02 15:04:05",
}
for _, fieldData := range columns {
if fieldData.GetType() != schemapb.DataType_Timestamptz {
continue
}
scalarField := fieldData.GetScalars()
if scalarField == nil || scalarField.GetStringData() == nil {
log.Warn("field data is not string data", zap.String("fieldName", fieldData.GetFieldName()))
return merr.WrapErrParameterInvalidMsg("field data is not string data")
}
stringData := scalarField.GetStringData().GetData()
utcTimestamps := make([]int64, len(stringData))
for i, isoStr := range stringData {
var t time.Time
var err error
// parse directly
t, err = time.Parse(time.RFC3339Nano, isoStr)
if err == nil {
utcTimestamps[i] = t.UnixMicro()
continue
}
// no timezone info; fall back to the collection -> database level default timezone
defaultTZ := "UTC"
if colTimezone != "" {
defaultTZ = colTimezone
}
location, err := time.LoadLocation(defaultTZ)
if err != nil {
log.Error("invalid timezone", zap.String("timezone", defaultTZ), zap.Error(err))
return merr.WrapErrParameterInvalidMsg("got invalid default timezone: %s", defaultTZ)
}
var parsed bool
for _, layout := range naiveLayouts {
t, err = time.ParseInLocation(layout, isoStr, location)
if err == nil {
parsed = true
break
}
}
if !parsed {
log.Warn("Can not parse timestamptz string", zap.String("timestamp_string", isoStr))
return merr.WrapErrParameterInvalidMsg("got invalid timestamptz string: %s", isoStr)
}
utcTimestamps[i] = t.UnixMicro()
}
// Replace data in place
fieldData.GetScalars().Data = &schemapb.ScalarField_TimestamptzData{
TimestamptzData: &schemapb.TimestamptzArray{
Data: utcTimestamps,
},
}
}
return nil
}
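// Parsing sketch (hypothetical inputs, collection timezone "Asia/Shanghai"):
//
//	"2024-01-02T15:04:05+08:00" // RFC 3339: parsed directly, offset honored
//	"2024-01-02 15:04:05"       // naive: interpreted in Asia/Shanghai, then
//	                            // stored as UTC microseconds via UnixMicro
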
func timestamptzUTC2IsoStr(results []*schemapb.FieldData, userDefineTimezone string, colTimezone string) error {
// Determine the target timezone based on priority: user-defined -> collection/database default -> UTC.
defaultTZ := "UTC"
if userDefineTimezone != "" {
defaultTZ = userDefineTimezone
} else if colTimezone != "" {
defaultTZ = colTimezone
}
location, err := time.LoadLocation(defaultTZ)
if err != nil {
log.Error("invalid timezone", zap.String("timezone", defaultTZ), zap.Error(err))
return merr.WrapErrParameterInvalidMsg("got invalid default timezone: %s", defaultTZ)
}
for _, fieldData := range results {
if fieldData.GetType() != schemapb.DataType_Timestamptz {
continue
}
scalarField := fieldData.GetScalars()
if scalarField == nil || scalarField.GetTimestamptzData() == nil {
if longData := scalarField.GetLongData(); longData != nil && len(longData.GetData()) > 0 {
log.Warn("field data is not Timestamptz data", zap.String("fieldName", fieldData.GetFieldName()))
return merr.WrapErrParameterInvalidMsg("field data for '%s' is not Timestamptz data", fieldData.GetFieldName())
}
// no timestamptz payload to convert; skip this field (mirrors extractFieldsFromResults)
continue
}
utcTimestamps := scalarField.GetTimestamptzData().GetData()
isoStrings := make([]string, len(utcTimestamps))
for i, ts := range utcTimestamps {
t := time.UnixMicro(ts).UTC()
localTime := t.In(location)
isoStrings[i] = localTime.Format(time.RFC3339Nano)
}
// Replace the TimestamptzData with the new StringData in place.
fieldData.GetScalars().Data = &schemapb.ScalarField_StringData{
StringData: &schemapb.StringArray{
Data: isoStrings,
},
}
}
return nil
}
// extractFields is a helper function to extract specific integer fields from a time.Time object.
// Supported fields are: "year", "month", "day", "hour", "minute", "second", "microsecond".
func extractFields(t time.Time, fieldList []string) ([]int64, error) {
extractedValues := make([]int64, 0, len(fieldList))
for _, field := range fieldList {
var val int64
switch strings.ToLower(field) {
case common.TszYear:
val = int64(t.Year())
case common.TszMonth:
val = int64(t.Month())
case common.TszDay:
val = int64(t.Day())
case common.TszHour:
val = int64(t.Hour())
case common.TszMinute:
val = int64(t.Minute())
case common.TszSecond:
val = int64(t.Second())
case common.TszMicrosecond:
val = int64(t.Nanosecond() / 1000)
default:
return nil, merr.WrapErrParameterInvalidMsg("unsupported field for extraction: %s, fields should be separated by ',' or ' '", field)
}
extractedValues = append(extractedValues, val)
}
return extractedValues, nil
}
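// Extraction sketch: for t = 2024-03-05 08:09:10.123456 UTC,
//
//	extractFields(t, []string{"year", "month", "microsecond"})
//	// -> []int64{2024, 3, 123456}
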
func extractFieldsFromResults(results []*schemapb.FieldData, precedenceTimezone []string, fieldList []string) error {
var targetLocation *time.Location
for _, tz := range precedenceTimezone {
if tz != "" {
loc, err := time.LoadLocation(tz)
if err != nil {
log.Error("invalid timezone provided in precedence list", zap.String("timezone", tz), zap.Error(err))
return merr.WrapErrParameterInvalidMsg("got invalid timezone: %s", tz)
}
targetLocation = loc
break // Use the first valid timezone found.
}
}
if targetLocation == nil {
targetLocation = time.UTC
}
for _, fieldData := range results {
if fieldData.GetType() != schemapb.DataType_Timestamptz {
continue
}
scalarField := fieldData.GetScalars()
if scalarField == nil || scalarField.GetTimestamptzData() == nil {
if longData := scalarField.GetLongData(); longData != nil && len(longData.GetData()) > 0 {
log.Warn("field data is not Timestamptz data, but found LongData instead", zap.String("fieldName", fieldData.GetFieldName()))
return merr.WrapErrParameterInvalidMsg("field data for '%s' is not Timestamptz data", fieldData.GetFieldName())
}
continue
}
utcTimestamps := scalarField.GetTimestamptzData().GetData()
extractedResults := make([]*schemapb.ScalarField, 0, len(fieldList))
for _, ts := range utcTimestamps {
t := time.UnixMicro(ts).UTC()
localTime := t.In(targetLocation)
values, err := extractFields(localTime, fieldList)
if err != nil {
return err
}
valuesScalarField := &schemapb.ScalarField_LongData{
LongData: &schemapb.LongArray{
Data: values,
},
}
extractedResults = append(extractedResults, &schemapb.ScalarField{
Data: valuesScalarField,
})
}
fieldData.GetScalars().Data = &schemapb.ScalarField_ArrayData{
ArrayData: &schemapb.ArrayArray{
Data: extractedResults,
ElementType: schemapb.DataType_Int64,
},
}
fieldData.Type = schemapb.DataType_Array
}
return nil
}
func genFunctionFields(ctx context.Context, insertMsg *msgstream.InsertMsg, schema *schemaInfo, partialUpdate bool) error {
allowNonBM25Outputs := common.GetCollectionAllowInsertNonBM25FunctionOutputs(schema.Properties)
fieldIDs := lo.Map(insertMsg.FieldsData, func(fieldData *schemapb.FieldData, _ int) int64 {
id, _ := schema.MapFieldID(fieldData.FieldName)
return id
})
// Since PartialUpdate is supported, the field_data here may not be complete
needProcessFunctions, err := typeutil.GetNeedProcessFunctions(fieldIDs, schema.Functions, allowNonBM25Outputs, partialUpdate)
if err != nil {
log.Ctx(ctx).Warn("check upsert field error", zap.String("collectionName", schema.Name), zap.Error(err))
return err
}
if embedding.HasNonBM25Functions(schema.CollectionSchema.Functions, []int64{}) {
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-genFunctionFields-call-function-udf")
defer sp.End()
exec, err := embedding.NewFunctionExecutor(schema.CollectionSchema, needProcessFunctions)
if err != nil {
return err
}
sp.AddEvent("Create-function-udf")
if err := exec.ProcessInsert(ctx, insertMsg); err != nil {
return err
}
sp.AddEvent("Call-function-udf")
}
return nil
}