mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
After https://github.com/milvus-io/milvus/pull/44557, the name of a field inside a STRUCT field becomes STRUCT_NAME[FIELD_NAME]. This PR makes import take that change into account. issue: https://github.com/milvus-io/milvus/issues/45006 ref: https://github.com/milvus-io/milvus/issues/42148 TODO: parquet is much more complex than csv/json, so it is left to a separate PR. --------- Signed-off-by: SpadeA <tangchenjie1210@gmail.com>
3011 lines
102 KiB
Go
3011 lines
102 KiB
Go
// Licensed to the LF AI & Data foundation under one
|
|
// or more contributor license agreements. See the NOTICE file
|
|
// distributed with this work for additional information
|
|
// regarding copyright ownership. The ASF licenses this file
|
|
// to you under the Apache License, Version 2.0 (the
|
|
// "License"); you may not use this file except in compliance
|
|
// with the License. You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"reflect"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
"unicode/utf8"
|
|
|
|
"github.com/cockroachdb/errors"
|
|
"github.com/samber/lo"
|
|
"go.opentelemetry.io/otel"
|
|
"go.uber.org/zap"
|
|
"golang.org/x/crypto/bcrypt"
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
|
|
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
|
|
"github.com/milvus-io/milvus/internal/json"
|
|
"github.com/milvus-io/milvus/internal/parser/planparserv2"
|
|
"github.com/milvus-io/milvus/internal/proxy/privilege"
|
|
"github.com/milvus-io/milvus/internal/types"
|
|
"github.com/milvus-io/milvus/internal/util/function/embedding"
|
|
"github.com/milvus-io/milvus/internal/util/hookutil"
|
|
"github.com/milvus-io/milvus/internal/util/indexparamcheck"
|
|
"github.com/milvus-io/milvus/internal/util/segcore"
|
|
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
|
|
"github.com/milvus-io/milvus/pkg/v2/common"
|
|
"github.com/milvus-io/milvus/pkg/v2/log"
|
|
"github.com/milvus-io/milvus/pkg/v2/mq/msgstream"
|
|
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
|
|
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
|
|
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
|
|
"github.com/milvus-io/milvus/pkg/v2/util"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/contextutil"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/crypto"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/merr"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/metric"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
|
|
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
|
|
)
|
|
|
|
const (
	// strongTS and boundedTS are timestamp sentinel values; how they are
	// interpreted is decided by their users, which are not visible in this
	// part of the file.
	// NOTE(review): presumably markers for strong / bounded consistency reads — confirm.
	strongTS  = 0
	boundedTS = 2

	// enableMultipleVectorFields indicates whether to enable multiple vector fields.
	enableMultipleVectorFields = true

	// defaultMaxArrayCapacity is the largest value accepted for the
	// max_capacity type param of an array field (see validateMaxCapacityPerRow).
	defaultMaxArrayCapacity = 4096

	// defaultMaxSearchRequest is not used in the visible part of the file.
	// NOTE(review): presumably caps the number of sub-requests in one search — confirm.
	defaultMaxSearchRequest = 1024

	// DefaultArithmeticIndexType name of default index type for scalar field
	DefaultArithmeticIndexType = indexparamcheck.IndexINVERTED

	// DefaultStringIndexType name of default index type for varChar/string field
	DefaultStringIndexType = indexparamcheck.IndexINVERTED
)
|
|
|
|
// logger is the package-level logger; every entry it emits carries the field
// role=typeutil.ProxyRole.
var logger = log.L().WithOptions(zap.Fields(zap.String("role", typeutil.ProxyRole)))
|
|
|
|
// transformStructFieldNames transforms struct field names to structName[fieldName] format
|
|
// This ensures global uniqueness while allowing same field names across different structs
|
|
func transformStructFieldNames(schema *schemapb.CollectionSchema) error {
|
|
for _, structArrayField := range schema.StructArrayFields {
|
|
structName := structArrayField.Name
|
|
for _, field := range structArrayField.Fields {
|
|
// Create transformed name: structName[fieldName]
|
|
newName := typeutil.ConcatStructFieldName(structName, field.Name)
|
|
field.Name = newName
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// restoreStructFieldNames restores original field names from structName[fieldName] format
|
|
// This is used when returning schema information to users (e.g., in describe collection)
|
|
func restoreStructFieldNames(schema *schemapb.CollectionSchema) error {
|
|
for _, structArrayField := range schema.StructArrayFields {
|
|
structName := structArrayField.Name
|
|
expectedPrefix := structName + "["
|
|
|
|
for _, field := range structArrayField.Fields {
|
|
if strings.HasPrefix(field.Name, expectedPrefix) && strings.HasSuffix(field.Name, "]") {
|
|
// Extract fieldName: remove "structName[" prefix and "]" suffix
|
|
field.Name = field.Name[len(expectedPrefix) : len(field.Name)-1]
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// extractOriginalFieldName returns the fieldName part of a name written as
// structName[fieldName]: everything between the first "[" and the final "]".
// It should only be called on transformed struct field names.
//
// An error is returned when the name contains no "[", does not end with "]",
// starts with "[" (no struct name), or has nothing between the brackets.
func extractOriginalFieldName(transformedName string) (string, error) {
	open := strings.IndexByte(transformedName, '[')
	switch {
	case open < 0:
		return "", fmt.Errorf("not a transformed struct field name: %s", transformedName)
	case !strings.HasSuffix(transformedName, "]"):
		return "", fmt.Errorf("invalid struct field format: %s, missing closing bracket", transformedName)
	case open == 0:
		return "", fmt.Errorf("invalid struct field format: %s, missing struct name", transformedName)
	}

	inner := transformedName[open+1 : len(transformedName)-1]
	if len(inner) == 0 {
		return "", fmt.Errorf("invalid struct field format: %s, empty field name", transformedName)
	}
	return inner, nil
}
|
|
|
|
// isAlpha reports whether c is an ASCII letter (A-Z or a-z).
func isAlpha(c uint8) bool {
	return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')
}
|
|
|
|
// isNumber reports whether c is an ASCII decimal digit (0-9).
func isNumber(c uint8) bool {
	return c >= '0' && c <= '9'
}
|
|
|
|
// check run analyzer params when collection name was set
|
|
func validateRunAnalyzer(req *milvuspb.RunAnalyzerRequest) error {
|
|
if req.GetAnalyzerParams() != "" {
|
|
return fmt.Errorf("run analyzer can't use analyzer params and (collection,field) in same time")
|
|
}
|
|
|
|
if req.GetFieldName() == "" {
|
|
return fmt.Errorf("must set field name when collection name was set")
|
|
}
|
|
|
|
if req.GetAnalyzerNames() != nil {
|
|
if len(req.GetAnalyzerNames()) != 1 && len(req.GetAnalyzerNames()) != len(req.GetPlaceholder()) {
|
|
return fmt.Errorf("only support set one analyzer name for all text or set analyzer name for each text, but now analzer name num: %d, text num: %d",
|
|
len(req.GetAnalyzerNames()), len(req.GetPlaceholder()))
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateMaxQueryResultWindow(offset int64, limit int64) error {
|
|
if offset < 0 {
|
|
return fmt.Errorf("%s [%d] is invalid, should be gte than 0", OffsetKey, offset)
|
|
}
|
|
if limit <= 0 {
|
|
return fmt.Errorf("%s [%d] is invalid, should be greater than 0", LimitKey, limit)
|
|
}
|
|
|
|
depth := offset + limit
|
|
maxQueryResultWindow := Params.QuotaConfig.MaxQueryResultWindow.GetAsInt64()
|
|
if depth <= 0 || depth > maxQueryResultWindow {
|
|
return fmt.Errorf("(offset+limit) should be in range [1, %d], but got %d", maxQueryResultWindow, depth)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func validateLimit(limit int64) error {
|
|
topKLimit := Params.QuotaConfig.TopKLimit.GetAsInt64()
|
|
if limit <= 0 || limit > topKLimit {
|
|
return fmt.Errorf("it should be in range [1, %d], but got %d", topKLimit, limit)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func validateNQLimit(limit int64) error {
|
|
nqLimit := Params.QuotaConfig.NQLimit.GetAsInt64()
|
|
if limit <= 0 || limit > nqLimit {
|
|
return fmt.Errorf("nq (number of search vector per search request) should be in range [1, %d], but got %d", nqLimit, limit)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func validateCollectionNameOrAlias(entity, entityType string) error {
|
|
if entity == "" {
|
|
return merr.WrapErrParameterInvalidMsg("collection %s should not be empty", entityType)
|
|
}
|
|
|
|
invalidMsg := fmt.Sprintf("Invalid collection %s: %s. ", entityType, entity)
|
|
if len(entity) > Params.ProxyCfg.MaxNameLength.GetAsInt() {
|
|
return merr.WrapErrParameterInvalidMsg("%s the length of a collection %s must be less than %s characters", invalidMsg, entityType,
|
|
Params.ProxyCfg.MaxNameLength.GetValue())
|
|
}
|
|
|
|
firstChar := entity[0]
|
|
if firstChar != '_' && !isAlpha(firstChar) {
|
|
return merr.WrapErrParameterInvalidMsg("%s the first character of a collection %s must be an underscore or letter", invalidMsg, entityType)
|
|
}
|
|
|
|
for i := 1; i < len(entity); i++ {
|
|
c := entity[i]
|
|
if c != '_' && !isAlpha(c) && !isNumber(c) {
|
|
return merr.WrapErrParameterInvalidMsg("%s collection %s can only contain numbers, letters and underscores", invalidMsg, entityType)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ValidatePrivilegeGroupName(groupName string) error {
|
|
if groupName == "" {
|
|
return merr.WrapErrPrivilegeGroupNameInvalid("privilege group name should not be empty")
|
|
}
|
|
|
|
if len(groupName) > Params.ProxyCfg.MaxNameLength.GetAsInt() {
|
|
return merr.WrapErrPrivilegeGroupNameInvalid(
|
|
"the length of a privilege group name %s must be less than %s characters", groupName, Params.ProxyCfg.MaxNameLength.GetValue())
|
|
}
|
|
|
|
firstChar := groupName[0]
|
|
if firstChar != '_' && !isAlpha(firstChar) {
|
|
return merr.WrapErrPrivilegeGroupNameInvalid(
|
|
"the first character of a privilege group name %s must be an underscore or letter", groupName)
|
|
}
|
|
|
|
for i := 1; i < len(groupName); i++ {
|
|
c := groupName[i]
|
|
if c != '_' && !isAlpha(c) && !isNumber(c) {
|
|
return merr.WrapErrParameterInvalidMsg(
|
|
"privilege group name %s can only contain numbers, letters and underscores", groupName)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ValidateResourceGroupName(entity string) error {
|
|
if entity == "" {
|
|
return errors.New("resource group name couldn't be empty")
|
|
}
|
|
|
|
invalidMsg := fmt.Sprintf("Invalid resource group name %s.", entity)
|
|
if len(entity) > Params.ProxyCfg.MaxNameLength.GetAsInt() {
|
|
return merr.WrapErrParameterInvalidMsg("%s the length of a resource group name must be less than %s characters",
|
|
invalidMsg, Params.ProxyCfg.MaxNameLength.GetValue())
|
|
}
|
|
|
|
firstChar := entity[0]
|
|
if firstChar != '_' && !isAlpha(firstChar) {
|
|
return merr.WrapErrParameterInvalidMsg("%s the first character of a resource group name must be an underscore or letter", invalidMsg)
|
|
}
|
|
|
|
for i := 1; i < len(entity); i++ {
|
|
c := entity[i]
|
|
if c != '_' && !isAlpha(c) && !isNumber(c) {
|
|
return merr.WrapErrParameterInvalidMsg("%s resource group name can only contain numbers, letters and underscores", invalidMsg)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ValidateDatabaseName(dbName string) error {
|
|
if dbName == "" {
|
|
return merr.WrapErrDatabaseNameInvalid(dbName, "database name couldn't be empty")
|
|
}
|
|
|
|
if len(dbName) > Params.ProxyCfg.MaxNameLength.GetAsInt() {
|
|
return merr.WrapErrDatabaseNameInvalid(dbName,
|
|
fmt.Sprintf("the length of a database name must be less than %d characters", Params.ProxyCfg.MaxNameLength.GetAsInt()))
|
|
}
|
|
|
|
firstChar := dbName[0]
|
|
if firstChar != '_' && !isAlpha(firstChar) {
|
|
return merr.WrapErrDatabaseNameInvalid(dbName,
|
|
"the first character of a database name must be an underscore or letter")
|
|
}
|
|
|
|
for i := 1; i < len(dbName); i++ {
|
|
c := dbName[i]
|
|
if c != '_' && !isAlpha(c) && !isNumber(c) {
|
|
return merr.WrapErrDatabaseNameInvalid(dbName,
|
|
"database name can only contain numbers, letters and underscores")
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ValidateCollectionAlias checks whether collAlias is a valid alias name for a
// collection. It returns nil when the alias is valid and otherwise the error
// produced by validateCollectionNameOrAlias (worded for an "alias").
func ValidateCollectionAlias(collAlias string) error {
	return validateCollectionNameOrAlias(collAlias, "alias")
}
|
|
|
|
// validateCollectionName checks whether collName is a valid collection name.
// It returns nil when the name is valid and otherwise the error produced by
// validateCollectionNameOrAlias (worded for a "name").
func validateCollectionName(collName string) error {
	return validateCollectionNameOrAlias(collName, "name")
}
|
|
|
|
func validatePartitionTag(partitionTag string, strictCheck bool) error {
|
|
partitionTag = strings.TrimSpace(partitionTag)
|
|
|
|
invalidMsg := "Invalid partition name: " + partitionTag + ". "
|
|
if partitionTag == "" {
|
|
msg := invalidMsg + "Partition name should not be empty."
|
|
return errors.New(msg)
|
|
}
|
|
if len(partitionTag) > Params.ProxyCfg.MaxNameLength.GetAsInt() {
|
|
msg := invalidMsg + "The length of a partition name must be less than " + Params.ProxyCfg.MaxNameLength.GetValue() + " characters."
|
|
return errors.New(msg)
|
|
}
|
|
|
|
if strictCheck {
|
|
firstChar := partitionTag[0]
|
|
if firstChar != '_' && !isAlpha(firstChar) && !isNumber(firstChar) {
|
|
msg := invalidMsg + "The first character of a partition name must be an underscore or letter."
|
|
return errors.New(msg)
|
|
}
|
|
|
|
tagSize := len(partitionTag)
|
|
for i := 1; i < tagSize; i++ {
|
|
c := partitionTag[i]
|
|
if c != '_' && !isAlpha(c) && !isNumber(c) && c != '-' {
|
|
msg := invalidMsg + "Partition name can only contain numbers, letters and underscores."
|
|
return errors.New(msg)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateFieldName(fieldName string) error {
|
|
fieldName = strings.TrimSpace(fieldName)
|
|
|
|
if fieldName == "" {
|
|
return merr.WrapErrFieldNameInvalid(fieldName, "field name should not be empty")
|
|
}
|
|
|
|
invalidMsg := "Invalid field name: " + fieldName + ". "
|
|
if len(fieldName) > Params.ProxyCfg.MaxNameLength.GetAsInt() {
|
|
msg := invalidMsg + "The length of a field name must be less than " + Params.ProxyCfg.MaxNameLength.GetValue() + " characters."
|
|
return merr.WrapErrFieldNameInvalid(fieldName, msg)
|
|
}
|
|
|
|
firstChar := fieldName[0]
|
|
if firstChar != '_' && !isAlpha(firstChar) {
|
|
msg := invalidMsg + "The first character of a field name must be an underscore or letter."
|
|
return merr.WrapErrFieldNameInvalid(fieldName, msg)
|
|
}
|
|
|
|
fieldNameSize := len(fieldName)
|
|
for i := 1; i < fieldNameSize; i++ {
|
|
c := fieldName[i]
|
|
if c != '_' && !isAlpha(c) && !isNumber(c) {
|
|
msg := invalidMsg + "Field name can only contain numbers, letters, and underscores."
|
|
return merr.WrapErrFieldNameInvalid(fieldName, msg)
|
|
}
|
|
}
|
|
if _, ok := common.FieldNameKeywords[fieldName]; ok {
|
|
msg := invalidMsg + fmt.Sprintf("%s is keyword in milvus.", fieldName)
|
|
return merr.WrapErrFieldNameInvalid(fieldName, msg)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func validateDimension(field *schemapb.FieldSchema) error {
|
|
exist := false
|
|
var dim int64
|
|
for _, param := range field.TypeParams {
|
|
if param.Key == common.DimKey {
|
|
exist = true
|
|
tmp, err := strconv.ParseInt(param.Value, 10, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
dim = tmp
|
|
break
|
|
}
|
|
}
|
|
// for sparse vector field, dim should not be specified
|
|
if typeutil.IsSparseFloatVectorType(field.DataType) {
|
|
if exist {
|
|
return fmt.Errorf("dim should not be specified for sparse vector field %s(%d)", field.GetName(), field.FieldID)
|
|
}
|
|
return nil
|
|
}
|
|
if !exist {
|
|
return errors.Newf("dimension is not defined in field type params of field %s, check type param `dim` for vector field", field.GetName())
|
|
}
|
|
|
|
if dim <= 1 {
|
|
return fmt.Errorf("invalid dimension: %d. should be in range 2 ~ %d", dim, Params.ProxyCfg.MaxDimension.GetAsInt())
|
|
}
|
|
|
|
// for dense vector field, dim will be limited by max_dimension
|
|
if typeutil.IsBinaryVectorType(field.DataType) {
|
|
if dim%8 != 0 {
|
|
return fmt.Errorf("invalid dimension: %d of field %s. binary vector dimension should be multiple of 8. ", dim, field.GetName())
|
|
}
|
|
if dim > Params.ProxyCfg.MaxDimension.GetAsInt64()*8 {
|
|
return fmt.Errorf("invalid dimension: %d of field %s. binary vector dimension should be in range 2 ~ %d", dim, field.GetName(), Params.ProxyCfg.MaxDimension.GetAsInt()*8)
|
|
}
|
|
} else {
|
|
if dim > Params.ProxyCfg.MaxDimension.GetAsInt64() {
|
|
return fmt.Errorf("invalid dimension: %d of field %s. float vector dimension should be in range 2 ~ %d", dim, field.GetName(), Params.ProxyCfg.MaxDimension.GetAsInt())
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func validateMaxLengthPerRow(collectionName string, field *schemapb.FieldSchema) error {
|
|
exist := false
|
|
for _, param := range field.TypeParams {
|
|
if param.Key != common.MaxLengthKey {
|
|
continue
|
|
}
|
|
|
|
maxLengthPerRow, err := strconv.ParseInt(param.Value, 10, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var defaultMaxLength int64
|
|
if field.DataType == schemapb.DataType_Text {
|
|
defaultMaxLength = Params.ProxyCfg.MaxTextLength.GetAsInt64()
|
|
} else {
|
|
defaultMaxLength = Params.ProxyCfg.MaxVarCharLength.GetAsInt64()
|
|
}
|
|
|
|
if maxLengthPerRow > defaultMaxLength || maxLengthPerRow <= 0 {
|
|
return merr.WrapErrParameterInvalidMsg("the maximum length specified for the field(%s) should be in (0, %d], but got %d instead", field.GetName(), defaultMaxLength, maxLengthPerRow)
|
|
}
|
|
exist = true
|
|
}
|
|
// if not exist type params max_length, return error
|
|
if !exist {
|
|
return fmt.Errorf("type param(max_length) should be specified for the field(%s) of collection %s", field.GetName(), collectionName)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateMaxCapacityPerRow(collectionName string, field *schemapb.FieldSchema) error {
|
|
exist := false
|
|
for _, param := range field.TypeParams {
|
|
if param.Key != common.MaxCapacityKey {
|
|
continue
|
|
}
|
|
|
|
maxCapacityPerRow, err := strconv.ParseInt(param.Value, 10, 64)
|
|
if err != nil {
|
|
return fmt.Errorf("the value for %s of field %s must be an integer", common.MaxCapacityKey, field.GetName())
|
|
}
|
|
if maxCapacityPerRow > defaultMaxArrayCapacity || maxCapacityPerRow <= 0 {
|
|
return errors.New("the maximum capacity specified for a Array should be in (0, 4096]")
|
|
}
|
|
exist = true
|
|
}
|
|
// if not exist type params max_length, return error
|
|
if !exist {
|
|
return fmt.Errorf("type param(max_capacity) should be specified for array field %s of collection %s", field.GetName(), collectionName)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func validateVectorFieldMetricType(field *schemapb.FieldSchema) error {
|
|
if !typeutil.IsVectorType(field.DataType) {
|
|
return nil
|
|
}
|
|
for _, params := range field.IndexParams {
|
|
if params.Key == common.MetricTypeKey {
|
|
return nil
|
|
}
|
|
}
|
|
return fmt.Errorf(`index param "metric_type" is not specified for index float vector %s`, field.GetName())
|
|
}
|
|
|
|
func validateDuplicatedFieldName(schema *schemapb.CollectionSchema) error {
|
|
names := make(map[string]bool)
|
|
validateFieldNames := func(name string) error {
|
|
_, ok := names[name]
|
|
if ok {
|
|
return errors.Newf("duplicated field name %s found", name)
|
|
}
|
|
names[name] = true
|
|
return nil
|
|
}
|
|
for _, field := range schema.Fields {
|
|
if err := validateFieldNames(field.Name); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
for _, structArrayField := range schema.StructArrayFields {
|
|
if err := validateFieldNames(structArrayField.Name); err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, field := range structArrayField.Fields {
|
|
if err := validateFieldNames(field.Name); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func validateElementType(dataType schemapb.DataType) error {
|
|
switch dataType {
|
|
case schemapb.DataType_Bool, schemapb.DataType_Int8, schemapb.DataType_Int16, schemapb.DataType_Int32,
|
|
schemapb.DataType_Int64, schemapb.DataType_Float, schemapb.DataType_Double, schemapb.DataType_VarChar:
|
|
return nil
|
|
case schemapb.DataType_String:
|
|
return errors.New("string data type not supported yet, please use VarChar type instead")
|
|
case schemapb.DataType_None:
|
|
return errors.New("element data type None is not valid")
|
|
}
|
|
return fmt.Errorf("element type %s is not supported", dataType.String())
|
|
}
|
|
|
|
func validateFieldType(schema *schemapb.CollectionSchema) error {
|
|
for _, field := range schema.GetFields() {
|
|
switch field.GetDataType() {
|
|
case schemapb.DataType_String:
|
|
return errors.New("string data type not supported yet, please use VarChar type instead")
|
|
case schemapb.DataType_None:
|
|
return errors.New("data type None is not valid")
|
|
case schemapb.DataType_Array:
|
|
if err := validateElementType(field.GetElementType()); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
for _, structArrayField := range schema.StructArrayFields {
|
|
for _, field := range structArrayField.Fields {
|
|
if field.GetDataType() != schemapb.DataType_Array && field.GetDataType() != schemapb.DataType_ArrayOfVector {
|
|
return errors.Newf("fields in StructArrayField must be Array or ArrayOfVector, field name = %s, field type = %s",
|
|
field.GetName(), field.GetDataType().String())
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ValidateFieldAutoID call after validatePrimaryKey
|
|
func ValidateFieldAutoID(coll *schemapb.CollectionSchema) error {
|
|
idx := -1
|
|
for i, field := range coll.Fields {
|
|
if field.AutoID {
|
|
if idx != -1 {
|
|
return fmt.Errorf("only one field can speficy AutoID with true, field name = %s, %s", coll.Fields[idx].Name, field.Name)
|
|
}
|
|
idx = i
|
|
if !field.IsPrimaryKey {
|
|
return fmt.Errorf("only primary field can speficy AutoID with true, field name = %s", field.Name)
|
|
}
|
|
}
|
|
}
|
|
for _, structArrayField := range coll.StructArrayFields {
|
|
for _, field := range structArrayField.Fields {
|
|
if field.AutoID {
|
|
return errors.Newf("autoID is not supported for struct field, field name = %s", field.Name)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ValidateField(field *schemapb.FieldSchema, schema *schemapb.CollectionSchema) error {
|
|
// validate field name
|
|
var err error
|
|
if err := validateFieldName(field.Name); err != nil {
|
|
return err
|
|
}
|
|
// validate dense vector field type parameters
|
|
isVectorType := typeutil.IsVectorType(field.DataType)
|
|
if isVectorType {
|
|
err = validateDimension(field)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
// valid max length per row parameters
|
|
// if max_length not specified, return error
|
|
if field.DataType == schemapb.DataType_VarChar ||
|
|
(field.GetDataType() == schemapb.DataType_Array && field.GetElementType() == schemapb.DataType_VarChar) {
|
|
err = validateMaxLengthPerRow(schema.Name, field)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
// valid max capacity for array per row parameters
|
|
// if max_capacity not specified, return error
|
|
if field.DataType == schemapb.DataType_Array {
|
|
if err = validateMaxCapacityPerRow(schema.Name, field); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if field.DataType == schemapb.DataType_ArrayOfVector {
|
|
return fmt.Errorf("array of vector can only be in the struct array field, field name: %s", field.Name)
|
|
}
|
|
|
|
// TODO should remove the index params in the field schema
|
|
indexParams := funcutil.KeyValuePair2Map(field.GetIndexParams())
|
|
if err = ValidateAutoIndexMmapConfig(isVectorType, indexParams); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ValidateFieldsInStruct(field *schemapb.FieldSchema, schema *schemapb.CollectionSchema) error {
|
|
// validate field name
|
|
var err error
|
|
if err := validateFieldName(field.Name); err != nil {
|
|
return err
|
|
}
|
|
|
|
if field.DataType != schemapb.DataType_Array && field.DataType != schemapb.DataType_ArrayOfVector {
|
|
return fmt.Errorf("Fields in StructArrayField can only be array or array of struct, but field %s is %s", field.Name, field.DataType.String())
|
|
}
|
|
|
|
if field.ElementType == schemapb.DataType_ArrayOfStruct || field.ElementType == schemapb.DataType_ArrayOfVector ||
|
|
field.ElementType == schemapb.DataType_Array {
|
|
return fmt.Errorf("Nested array is not supported %s", field.Name)
|
|
}
|
|
|
|
if field.DataType == schemapb.DataType_Array {
|
|
if err := validateElementType(field.GetElementType()); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
// TODO(SpadeA): only support float vector now
|
|
if field.GetElementType() != schemapb.DataType_FloatVector {
|
|
return fmt.Errorf("Unsupported element type of array field %s, now only float vector is supported", field.Name)
|
|
}
|
|
|
|
// if !typeutil.IsVectorType(field.GetElementType()) {
|
|
// return fmt.Errorf("Inconsistent schema: element type of array field %s is not a vector type", field.Name)
|
|
// }
|
|
err = validateDimension(field)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// valid max length per row parameters
|
|
// if max_length not specified, return error
|
|
if field.ElementType == schemapb.DataType_VarChar {
|
|
err = validateMaxLengthPerRow(schema.Name, field)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// todo(SpadeA): make nullable field in struct array supported
|
|
if field.GetNullable() {
|
|
return fmt.Errorf("nullable is not supported for fields in struct array now, fieldName = %s", field.Name)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ValidateStructArrayField(structArrayField *schemapb.StructArrayFieldSchema, schema *schemapb.CollectionSchema) error {
|
|
if len(structArrayField.Fields) == 0 {
|
|
return fmt.Errorf("struct array field %s has no sub-fields", structArrayField.Name)
|
|
}
|
|
|
|
for _, subField := range structArrayField.Fields {
|
|
if err := ValidateFieldsInStruct(subField, schema); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func validatePrimaryKey(coll *schemapb.CollectionSchema) error {
|
|
idx := -1
|
|
for i, field := range coll.Fields {
|
|
if field.IsPrimaryKey {
|
|
if idx != -1 {
|
|
return fmt.Errorf("there are more than one primary key, field name = %s, %s", coll.Fields[idx].Name, field.Name)
|
|
}
|
|
|
|
// The type of the primary key field can only be int64 and varchar
|
|
if field.DataType != schemapb.DataType_Int64 && field.DataType != schemapb.DataType_VarChar {
|
|
return errors.New("the data type of primary key should be Int64 or VarChar")
|
|
}
|
|
|
|
// varchar field do not support autoID
|
|
// If autoID is required, it is recommended to use int64 field as the primary key
|
|
//if field.DataType == schemapb.DataType_VarChar {
|
|
// if field.AutoID {
|
|
// return errors.New("autoID is not supported when the VarChar field is the primary key")
|
|
// }
|
|
//}
|
|
|
|
idx = i
|
|
}
|
|
}
|
|
if idx == -1 {
|
|
return errors.New("primary key is not specified")
|
|
}
|
|
|
|
for _, structArrayField := range coll.StructArrayFields {
|
|
for _, field := range structArrayField.Fields {
|
|
if field.IsPrimaryKey {
|
|
return errors.Newf("primary key is not supported for struct field, field name = %s", field.Name)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateDynamicField(coll *schemapb.CollectionSchema) error {
|
|
for _, field := range coll.Fields {
|
|
if field.IsDynamic {
|
|
return errors.New("cannot explicitly set a field as a dynamic field")
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RepeatedKeyValToMap transfer the kv pairs to map.
|
|
func RepeatedKeyValToMap(kvPairs []*commonpb.KeyValuePair) (map[string]string, error) {
|
|
resMap := make(map[string]string)
|
|
for _, kv := range kvPairs {
|
|
_, ok := resMap[kv.Key]
|
|
if ok {
|
|
return nil, fmt.Errorf("duplicated param key: %s", kv.Key)
|
|
}
|
|
resMap[kv.Key] = kv.Value
|
|
}
|
|
return resMap, nil
|
|
}
|
|
|
|
// isVector check if dataType belongs to vector type.
|
|
func isVector(dataType schemapb.DataType) (bool, error) {
|
|
switch dataType {
|
|
case schemapb.DataType_Bool, schemapb.DataType_Int8,
|
|
schemapb.DataType_Int16, schemapb.DataType_Int32,
|
|
schemapb.DataType_Int64,
|
|
schemapb.DataType_Float, schemapb.DataType_Double:
|
|
return false, nil
|
|
|
|
case schemapb.DataType_FloatVector, schemapb.DataType_BinaryVector, schemapb.DataType_Float16Vector, schemapb.DataType_BFloat16Vector, schemapb.DataType_SparseFloatVector:
|
|
return true, nil
|
|
}
|
|
|
|
return false, fmt.Errorf("invalid data type: %d", dataType)
|
|
}
|
|
|
|
func validateMetricType(dataType schemapb.DataType, metricTypeStrRaw string) error {
|
|
metricTypeStr := strings.ToUpper(metricTypeStrRaw)
|
|
switch metricTypeStr {
|
|
case metric.L2, metric.IP, metric.COSINE:
|
|
if typeutil.IsFloatVectorType(dataType) {
|
|
return nil
|
|
}
|
|
case metric.JACCARD, metric.HAMMING, metric.SUBSTRUCTURE, metric.SUPERSTRUCTURE, metric.MHJACCARD:
|
|
if dataType == schemapb.DataType_BinaryVector {
|
|
return nil
|
|
}
|
|
}
|
|
return fmt.Errorf("data_type %s mismatch with metric_type %s", dataType.String(), metricTypeStrRaw)
|
|
}
|
|
|
|
func validateFunction(coll *schemapb.CollectionSchema) error {
|
|
nameMap := lo.SliceToMap(coll.GetFields(), func(field *schemapb.FieldSchema) (string, *schemapb.FieldSchema) {
|
|
return field.GetName(), field
|
|
})
|
|
usedOutputField := typeutil.NewSet[string]()
|
|
usedFunctionName := typeutil.NewSet[string]()
|
|
|
|
// reset `IsFunctionOuput` despite any user input, this shall be determined by function def only.
|
|
for _, field := range coll.Fields {
|
|
field.IsFunctionOutput = false
|
|
}
|
|
|
|
for _, function := range coll.GetFunctions() {
|
|
if err := checkFunctionBasicParams(function); err != nil {
|
|
return err
|
|
}
|
|
|
|
if usedFunctionName.Contain(function.GetName()) {
|
|
return fmt.Errorf("duplicate function name: %s", function.GetName())
|
|
}
|
|
|
|
usedFunctionName.Insert(function.GetName())
|
|
inputFields := []*schemapb.FieldSchema{}
|
|
for _, name := range function.GetInputFieldNames() {
|
|
inputField, ok := nameMap[name]
|
|
if !ok {
|
|
return fmt.Errorf("function input field not found: %s", name)
|
|
}
|
|
inputFields = append(inputFields, inputField)
|
|
}
|
|
|
|
if err := checkFunctionInputField(function, inputFields); err != nil {
|
|
return err
|
|
}
|
|
|
|
outputFields := make([]*schemapb.FieldSchema, len(function.GetOutputFieldNames()))
|
|
for i, name := range function.GetOutputFieldNames() {
|
|
outputField, ok := nameMap[name]
|
|
if !ok {
|
|
return fmt.Errorf("function output field not found: %s", name)
|
|
}
|
|
|
|
if outputField.GetIsPrimaryKey() {
|
|
return fmt.Errorf("function output field cannot be primary key: function %s, field %s", function.GetName(), outputField.GetName())
|
|
}
|
|
|
|
if outputField.GetIsPartitionKey() || outputField.GetIsClusteringKey() {
|
|
return fmt.Errorf("function output field cannot be partition key or clustering key: function %s, field %s", function.GetName(), outputField.GetName())
|
|
}
|
|
|
|
if outputField.GetNullable() {
|
|
return fmt.Errorf("function output field cannot be nullable: function %s, field %s", function.GetName(), outputField.GetName())
|
|
}
|
|
|
|
outputField.IsFunctionOutput = true
|
|
outputFields[i] = outputField
|
|
if usedOutputField.Contain(name) {
|
|
return fmt.Errorf("duplicate function output field: function %s, field %s", function.GetName(), name)
|
|
}
|
|
usedOutputField.Insert(name)
|
|
}
|
|
|
|
if err := checkFunctionOutputField(function, outputFields); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err := embedding.ValidateFunctions(coll); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func checkFunctionOutputField(fSchema *schemapb.FunctionSchema, fields []*schemapb.FieldSchema) error {
|
|
switch fSchema.GetType() {
|
|
case schemapb.FunctionType_BM25:
|
|
if len(fields) != 1 {
|
|
return fmt.Errorf("BM25 function only need 1 output field, but got %d", len(fields))
|
|
}
|
|
|
|
if !typeutil.IsSparseFloatVectorType(fields[0].GetDataType()) {
|
|
return fmt.Errorf("BM25 function output field must be a SparseFloatVector field, but got %s", fields[0].DataType.String())
|
|
}
|
|
case schemapb.FunctionType_TextEmbedding:
|
|
if err := embedding.TextEmbeddingOutputsCheck(fields); err != nil {
|
|
return err
|
|
}
|
|
default:
|
|
return errors.New("check output field for unknown function type")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func checkFunctionInputField(function *schemapb.FunctionSchema, fields []*schemapb.FieldSchema) error {
|
|
switch function.GetType() {
|
|
case schemapb.FunctionType_BM25:
|
|
if len(fields) != 1 || (fields[0].DataType != schemapb.DataType_VarChar && fields[0].DataType != schemapb.DataType_Text) {
|
|
return fmt.Errorf("BM25 function input field must be a VARCHAR/TEXT field, got %d field with type %s",
|
|
len(fields), fields[0].DataType.String())
|
|
}
|
|
h := typeutil.CreateFieldSchemaHelper(fields[0])
|
|
if !h.EnableAnalyzer() {
|
|
return errors.New("BM25 function input field must set enable_analyzer to true")
|
|
}
|
|
case schemapb.FunctionType_TextEmbedding:
|
|
if err := embedding.TextEmbeddingInputsCheck(function.GetName(), fields); err != nil {
|
|
return err
|
|
}
|
|
default:
|
|
return errors.New("check input field with unknown function type")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func checkFunctionBasicParams(function *schemapb.FunctionSchema) error {
|
|
if function.GetName() == "" {
|
|
return errors.New("function name cannot be empty")
|
|
}
|
|
if len(function.GetInputFieldNames()) == 0 {
|
|
return fmt.Errorf("function input field names cannot be empty, function: %s", function.GetName())
|
|
}
|
|
if len(function.GetOutputFieldNames()) == 0 {
|
|
return fmt.Errorf("function output field names cannot be empty, function: %s", function.GetName())
|
|
}
|
|
for _, input := range function.GetInputFieldNames() {
|
|
if input == "" {
|
|
return fmt.Errorf("function input field name cannot be empty string, function: %s", function.GetName())
|
|
}
|
|
// if input occurs more than once, error
|
|
if lo.Count(function.GetInputFieldNames(), input) > 1 {
|
|
return fmt.Errorf("each function input field should be used exactly once in the same function, function: %s, input field: %s", function.GetName(), input)
|
|
}
|
|
}
|
|
for _, output := range function.GetOutputFieldNames() {
|
|
if output == "" {
|
|
return fmt.Errorf("function output field name cannot be empty string, function: %s", function.GetName())
|
|
}
|
|
if lo.Count(function.GetInputFieldNames(), output) > 0 {
|
|
return fmt.Errorf("a single field cannot be both input and output in the same function, function: %s, field: %s", function.GetName(), output)
|
|
}
|
|
if lo.Count(function.GetOutputFieldNames(), output) > 1 {
|
|
return fmt.Errorf("each function output field should be used exactly once in the same function, function: %s, output field: %s", function.GetName(), output)
|
|
}
|
|
}
|
|
switch function.GetType() {
|
|
case schemapb.FunctionType_BM25:
|
|
if len(function.GetParams()) != 0 {
|
|
return errors.New("BM25 function accepts no params")
|
|
}
|
|
case schemapb.FunctionType_TextEmbedding:
|
|
if len(function.GetParams()) == 0 {
|
|
return errors.New("TextEmbedding function accepts no params")
|
|
}
|
|
default:
|
|
return errors.New("check function params with unknown function type")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// validateMultipleVectorFields check if schema has multiple vector fields.
|
|
func validateMultipleVectorFields(schema *schemapb.CollectionSchema) error {
|
|
vecExist := false
|
|
var vecName string
|
|
|
|
for i := range schema.Fields {
|
|
name := schema.Fields[i].Name
|
|
dType := schema.Fields[i].DataType
|
|
isVec := typeutil.IsVectorType(dType)
|
|
if isVec && vecExist && !enableMultipleVectorFields {
|
|
return fmt.Errorf(
|
|
"multiple vector fields is not supported, fields name: %s, %s",
|
|
vecName,
|
|
name,
|
|
)
|
|
} else if isVec {
|
|
vecExist = true
|
|
vecName = name
|
|
}
|
|
}
|
|
|
|
// todo(Spadea): should be there any check between vectors in struct fields?
|
|
|
|
return nil
|
|
}
|
|
|
|
func validateLoadFieldsList(schema *schemapb.CollectionSchema) error {
|
|
var vectorCnt int
|
|
for _, field := range schema.Fields {
|
|
shouldLoad, err := common.ShouldFieldBeLoaded(field.GetTypeParams())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// shoud load field, skip other check
|
|
if shouldLoad {
|
|
if typeutil.IsVectorType(field.GetDataType()) {
|
|
vectorCnt++
|
|
}
|
|
continue
|
|
}
|
|
|
|
if field.IsPrimaryKey {
|
|
return merr.WrapErrParameterInvalidMsg("Primary key field %s cannot skip loading", field.GetName())
|
|
}
|
|
|
|
if field.IsPartitionKey {
|
|
return merr.WrapErrParameterInvalidMsg("Partition Key field %s cannot skip loading", field.GetName())
|
|
}
|
|
|
|
if field.IsClusteringKey {
|
|
return merr.WrapErrParameterInvalidMsg("Clustering Key field %s cannot skip loading", field.GetName())
|
|
}
|
|
}
|
|
|
|
for _, structArrayField := range schema.StructArrayFields {
|
|
for _, field := range structArrayField.Fields {
|
|
shouldLoad, err := common.ShouldFieldBeLoaded(field.GetTypeParams())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if shouldLoad {
|
|
if typeutil.IsVectorType(field.ElementType) {
|
|
vectorCnt++
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
}
|
|
|
|
if vectorCnt == 0 {
|
|
return merr.WrapErrParameterInvalidMsg("cannot config all vector field(s) skip loading")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// parsePrimaryFieldData2IDs get IDs to fill grpc result, for example insert request, delete request etc.
|
|
func parsePrimaryFieldData2IDs(fieldData *schemapb.FieldData) (*schemapb.IDs, error) {
|
|
primaryData := &schemapb.IDs{}
|
|
switch fieldData.Field.(type) {
|
|
case *schemapb.FieldData_Scalars:
|
|
scalarField := fieldData.GetScalars()
|
|
switch scalarField.Data.(type) {
|
|
case *schemapb.ScalarField_LongData:
|
|
primaryData.IdField = &schemapb.IDs_IntId{
|
|
IntId: scalarField.GetLongData(),
|
|
}
|
|
case *schemapb.ScalarField_StringData:
|
|
primaryData.IdField = &schemapb.IDs_StrId{
|
|
StrId: scalarField.GetStringData(),
|
|
}
|
|
default:
|
|
return nil, merr.WrapErrParameterInvalidMsg("currently only support DataType Int64 or VarChar as PrimaryField")
|
|
}
|
|
default:
|
|
return nil, merr.WrapErrParameterInvalidMsg("currently not support vector field as PrimaryField")
|
|
}
|
|
|
|
return primaryData, nil
|
|
}
|
|
|
|
// autoGenPrimaryFieldData generate primary data when autoID == true
|
|
func autoGenPrimaryFieldData(fieldSchema *schemapb.FieldSchema, data interface{}) (*schemapb.FieldData, error) {
|
|
var fieldData schemapb.FieldData
|
|
fieldData.FieldName = fieldSchema.Name
|
|
fieldData.Type = fieldSchema.DataType
|
|
switch data := data.(type) {
|
|
case []int64:
|
|
switch fieldData.Type {
|
|
case schemapb.DataType_Int64:
|
|
fieldData.Field = &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_LongData{
|
|
LongData: &schemapb.LongArray{
|
|
Data: data,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
case schemapb.DataType_VarChar:
|
|
strIDs := make([]string, len(data))
|
|
for i, v := range data {
|
|
strIDs[i] = strconv.FormatInt(v, 10)
|
|
}
|
|
fieldData.Field = &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_StringData{
|
|
StringData: &schemapb.StringArray{
|
|
Data: strIDs,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
default:
|
|
return nil, errors.New("currently only support autoID for int64 and varchar PrimaryField")
|
|
}
|
|
default:
|
|
return nil, errors.New("currently only int64 is supported as the data source for the autoID of a PrimaryField")
|
|
}
|
|
|
|
return &fieldData, nil
|
|
}
|
|
|
|
func autoGenDynamicFieldData(data [][]byte) *schemapb.FieldData {
|
|
return &schemapb.FieldData{
|
|
FieldName: common.MetaFieldName,
|
|
Type: schemapb.DataType_JSON,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_JsonData{
|
|
JsonData: &schemapb.JSONArray{
|
|
Data: data,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
IsDynamic: true,
|
|
}
|
|
}
|
|
|
|
// fillFieldPropertiesBySchema set fieldID to fieldData according FieldSchemas
|
|
func fillFieldPropertiesBySchema(columns []*schemapb.FieldData, schema *schemapb.CollectionSchema) error {
|
|
fieldName2Schema := make(map[string]*schemapb.FieldSchema)
|
|
|
|
expectColumnNum := 0
|
|
for _, field := range schema.GetFields() {
|
|
fieldName2Schema[field.Name] = field
|
|
if !typeutil.IsBM25FunctionOutputField(field, schema) {
|
|
expectColumnNum++
|
|
}
|
|
}
|
|
|
|
for _, structField := range schema.GetStructArrayFields() {
|
|
for _, field := range structField.GetFields() {
|
|
fieldName2Schema[field.Name] = field
|
|
expectColumnNum++
|
|
}
|
|
}
|
|
|
|
if len(columns) != expectColumnNum {
|
|
return fmt.Errorf("len(columns) mismatch the expectColumnNum, expectColumnNum: %d, len(columns): %d",
|
|
expectColumnNum, len(columns))
|
|
}
|
|
|
|
for _, fieldData := range columns {
|
|
if fieldSchema, ok := fieldName2Schema[fieldData.FieldName]; ok {
|
|
fieldData.FieldId = fieldSchema.FieldID
|
|
fieldData.Type = fieldSchema.DataType
|
|
|
|
// Set the ElementType because it may not be set in the insert request.
|
|
if fieldData.Type == schemapb.DataType_Array {
|
|
fd, ok := fieldData.Field.(*schemapb.FieldData_Scalars)
|
|
if !ok || fd.Scalars.GetArrayData() == nil {
|
|
return fmt.Errorf("field convert FieldData_Scalars fail in fieldData, fieldName: %s,"+
|
|
" collectionName:%s", fieldData.FieldName, schema.Name)
|
|
}
|
|
fd.Scalars.GetArrayData().ElementType = fieldSchema.ElementType
|
|
} else if fieldData.Type == schemapb.DataType_ArrayOfVector {
|
|
fd, ok := fieldData.Field.(*schemapb.FieldData_Vectors)
|
|
if !ok || fd.Vectors.GetVectorArray() == nil {
|
|
return fmt.Errorf("field convert FieldData_Vectors fail in fieldData, fieldName: %s,"+
|
|
" collectionName:%s", fieldData.FieldName, schema.Name)
|
|
}
|
|
fd.Vectors.GetVectorArray().ElementType = fieldSchema.ElementType
|
|
}
|
|
} else {
|
|
return fmt.Errorf("fieldName %v not exist in collection schema", fieldData.FieldName)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func ValidateUsername(username string) error {
|
|
username = strings.TrimSpace(username)
|
|
|
|
if username == "" {
|
|
return merr.WrapErrParameterInvalidMsg("username must be not empty")
|
|
}
|
|
|
|
if len(username) > Params.ProxyCfg.MaxUsernameLength.GetAsInt() {
|
|
return merr.WrapErrParameterInvalidMsg("invalid username %s with length %d, the length of username must be less than %d", username, len(username), Params.ProxyCfg.MaxUsernameLength.GetValue())
|
|
}
|
|
|
|
firstChar := username[0]
|
|
if !isAlpha(firstChar) {
|
|
return merr.WrapErrParameterInvalidMsg("invalid user name %s, the first character must be a letter, but got %s", username, string(firstChar))
|
|
}
|
|
|
|
usernameSize := len(username)
|
|
for i := 1; i < usernameSize; i++ {
|
|
c := username[i]
|
|
if c != '_' && c != '-' && c != '.' && !isAlpha(c) && !isNumber(c) {
|
|
return merr.WrapErrParameterInvalidMsg("invalid user name %s, username must contain only numbers, letters, underscores, dots, and hyphens, but got %s", username, c)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ValidatePassword(password string) error {
|
|
if len(password) < Params.ProxyCfg.MinPasswordLength.GetAsInt() || len(password) > Params.ProxyCfg.MaxPasswordLength.GetAsInt() {
|
|
return merr.WrapErrParameterInvalidRange(Params.ProxyCfg.MinPasswordLength.GetAsInt(),
|
|
Params.ProxyCfg.MaxPasswordLength.GetAsInt(),
|
|
len(password), "invalid password length")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ReplaceID2Name returns oldStr with every occurrence of the base-10 text of
// id replaced by name. Matching is purely textual, so digits that are part
// of a longer number are replaced as well.
func ReplaceID2Name(oldStr string, id int64, name string) string {
	idText := strconv.FormatInt(id, 10)
	return strings.Replace(oldStr, idText, name, -1)
}
|
|
|
|
func parseGuaranteeTsFromConsistency(ts, tMax typeutil.Timestamp, consistency commonpb.ConsistencyLevel) typeutil.Timestamp {
|
|
switch consistency {
|
|
case commonpb.ConsistencyLevel_Strong:
|
|
ts = tMax
|
|
case commonpb.ConsistencyLevel_Bounded:
|
|
ratio := Params.CommonCfg.GracefulTime.GetAsDuration(time.Millisecond)
|
|
ts = tsoutil.AddPhysicalDurationOnTs(tMax, -ratio)
|
|
case commonpb.ConsistencyLevel_Eventually:
|
|
ts = 1
|
|
}
|
|
return ts
|
|
}
|
|
|
|
func parseGuaranteeTs(ts, tMax typeutil.Timestamp) typeutil.Timestamp {
|
|
switch ts {
|
|
case strongTS:
|
|
ts = tMax
|
|
case boundedTS:
|
|
ratio := Params.CommonCfg.GracefulTime.GetAsDuration(time.Millisecond)
|
|
ts = tsoutil.AddPhysicalDurationOnTs(tMax, -ratio)
|
|
}
|
|
return ts
|
|
}
|
|
|
|
func getMaxMvccTsFromChannels(channelsTs map[string]uint64, beginTs typeutil.Timestamp) typeutil.Timestamp {
|
|
maxTs := typeutil.Timestamp(0)
|
|
for _, ts := range channelsTs {
|
|
if ts > maxTs {
|
|
maxTs = ts
|
|
}
|
|
}
|
|
|
|
if maxTs == 0 {
|
|
log.Warn("no channel ts found, use beginTs instead")
|
|
return beginTs
|
|
}
|
|
|
|
return maxTs
|
|
}
|
|
|
|
func validateName(entity string, nameType string) error {
|
|
return validateNameWithCustomChars(entity, nameType, Params.ProxyCfg.NameValidationAllowedChars.GetValue())
|
|
}
|
|
|
|
func validateNameWithCustomChars(entity string, nameType string, allowedChars string) error {
|
|
entity = strings.TrimSpace(entity)
|
|
|
|
if entity == "" {
|
|
return merr.WrapErrParameterInvalid("not empty", entity, nameType+" should be not empty")
|
|
}
|
|
|
|
if len(entity) > Params.ProxyCfg.MaxNameLength.GetAsInt() {
|
|
return merr.WrapErrParameterInvalidRange(0,
|
|
Params.ProxyCfg.MaxNameLength.GetAsInt(),
|
|
len(entity),
|
|
fmt.Sprintf("the length of %s must be not greater than limit", nameType))
|
|
}
|
|
|
|
firstChar := entity[0]
|
|
if firstChar != '_' && !isAlpha(firstChar) {
|
|
return merr.WrapErrParameterInvalid('_',
|
|
firstChar,
|
|
fmt.Sprintf("the first character of %s must be an underscore or letter", nameType))
|
|
}
|
|
|
|
for i := 1; i < len(entity); i++ {
|
|
c := entity[i]
|
|
if c != '_' && !isAlpha(c) && !isNumber(c) && !strings.ContainsRune(allowedChars, rune(c)) {
|
|
return merr.WrapErrParameterInvalidMsg("%s can only contain numbers, letters, underscores, and allowed characters (%s), found %c at %d", nameType, allowedChars, c, i)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ValidateRoleName(entity string) error {
|
|
return validateNameWithCustomChars(entity, "role name", Params.ProxyCfg.RoleNameValidationAllowedChars.GetValue())
|
|
}
|
|
|
|
func IsDefaultRole(roleName string) bool {
|
|
for _, defaultRole := range util.DefaultRoles {
|
|
if defaultRole == roleName {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func ValidateObjectName(entity string) error {
|
|
if util.IsAnyWord(entity) {
|
|
return nil
|
|
}
|
|
return validateName(entity, "object name")
|
|
}
|
|
|
|
func ValidateCollectionName(entity string) error {
|
|
if util.IsAnyWord(entity) {
|
|
return nil
|
|
}
|
|
return validateName(entity, "collection name")
|
|
}
|
|
|
|
func ValidateObjectType(entity string) error {
|
|
return validateName(entity, "ObjectType")
|
|
}
|
|
|
|
func ValidatePrivilege(entity string) error {
|
|
if util.IsAnyWord(entity) {
|
|
return nil
|
|
}
|
|
return validateName(entity, "Privilege")
|
|
}
|
|
|
|
func GetCurUserFromContext(ctx context.Context) (string, error) {
|
|
return contextutil.GetCurUserFromContext(ctx)
|
|
}
|
|
|
|
func GetCurUserFromContextOrDefault(ctx context.Context) string {
|
|
username, _ := GetCurUserFromContext(ctx)
|
|
return username
|
|
}
|
|
|
|
func GetCurDBNameFromContextOrDefault(ctx context.Context) string {
|
|
md, ok := metadata.FromIncomingContext(ctx)
|
|
if !ok {
|
|
return util.DefaultDBName
|
|
}
|
|
dbNameData := md[strings.ToLower(util.HeaderDBName)]
|
|
if len(dbNameData) < 1 || dbNameData[0] == "" {
|
|
return util.DefaultDBName
|
|
}
|
|
return dbNameData[0]
|
|
}
|
|
|
|
func NewContextWithMetadata(ctx context.Context, username string, dbName string) context.Context {
|
|
dbKey := strings.ToLower(util.HeaderDBName)
|
|
if dbName != "" {
|
|
ctx = contextutil.AppendToIncomingContext(ctx, dbKey, dbName)
|
|
}
|
|
if username != "" {
|
|
originValue := fmt.Sprintf("%s%s%s", username, util.CredentialSeperator, username)
|
|
authKey := strings.ToLower(util.HeaderAuthorize)
|
|
authValue := crypto.Base64Encode(originValue)
|
|
ctx = contextutil.AppendToIncomingContext(ctx, authKey, authValue)
|
|
}
|
|
return ctx
|
|
}
|
|
|
|
func AppendUserInfoForRPC(ctx context.Context) context.Context {
|
|
curUser, _ := GetCurUserFromContext(ctx)
|
|
if curUser != "" {
|
|
originValue := fmt.Sprintf("%s%s%s", curUser, util.CredentialSeperator, curUser)
|
|
authKey := strings.ToLower(util.HeaderAuthorize)
|
|
authValue := crypto.Base64Encode(originValue)
|
|
ctx = metadata.AppendToOutgoingContext(ctx, authKey, authValue)
|
|
}
|
|
return ctx
|
|
}
|
|
|
|
func GetRole(username string) ([]string, error) {
|
|
privCache := privilege.GetPrivilegeCache()
|
|
if privCache == nil {
|
|
return []string{}, merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait")
|
|
}
|
|
return privCache.GetUserRole(username), nil
|
|
}
|
|
|
|
func PasswordVerify(ctx context.Context, username, rawPwd string) bool {
|
|
return passwordVerify(ctx, username, rawPwd, privilege.GetPrivilegeCache())
|
|
}
|
|
|
|
func VerifyAPIKey(rawToken string) (string, error) {
|
|
hoo := hookutil.GetHook()
|
|
user, err := hoo.VerifyAPIKey(rawToken)
|
|
if err != nil {
|
|
log.Warn("fail to verify apikey", zap.String("api_key", rawToken), zap.Error(err))
|
|
return "", merr.WrapErrParameterInvalidMsg("invalid apikey: [%s]", rawToken)
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
// PasswordVerify verify password
|
|
func passwordVerify(ctx context.Context, username, rawPwd string, privilegeCache privilege.PrivilegeCache) bool {
|
|
// it represents the cache miss if Sha256Password is empty within credInfo, which shall be updated first connection.
|
|
// meanwhile, generating Sha256Password depends on raw password and encrypted password will not cache.
|
|
credInfo, err := privilege.GetPrivilegeCache().GetCredentialInfo(ctx, username)
|
|
if err != nil {
|
|
log.Ctx(ctx).Error("found no credential", zap.String("username", username), zap.Error(err))
|
|
return false
|
|
}
|
|
|
|
// hit cache
|
|
sha256Pwd := crypto.SHA256(rawPwd, credInfo.Username)
|
|
if credInfo.Sha256Password != "" {
|
|
return sha256Pwd == credInfo.Sha256Password
|
|
}
|
|
|
|
// miss cache, verify against encrypted password from etcd
|
|
if err := bcrypt.CompareHashAndPassword([]byte(credInfo.EncryptedPassword), []byte(rawPwd)); err != nil {
|
|
log.Ctx(ctx).Error("Verify password failed", zap.Error(err))
|
|
return false
|
|
}
|
|
|
|
// update cache after miss cache
|
|
credInfo.Sha256Password = sha256Pwd
|
|
log.Ctx(ctx).Debug("get credential miss cache, update cache with", zap.Any("credential", credInfo))
|
|
privilegeCache.UpdateCredential(credInfo)
|
|
return true
|
|
}
|
|
|
|
func translatePkOutputFields(schema *schemapb.CollectionSchema) ([]string, []int64) {
|
|
pkNames := []string{}
|
|
fieldIDs := []int64{}
|
|
for _, field := range schema.Fields {
|
|
if field.IsPrimaryKey {
|
|
pkNames = append(pkNames, field.GetName())
|
|
fieldIDs = append(fieldIDs, field.GetFieldID())
|
|
}
|
|
}
|
|
return pkNames, fieldIDs
|
|
}
|
|
|
|
// recallCal returns the fraction of entries in results that also occur in
// gts (the ground truth). Each entry of results counts at most once per
// position, so duplicated result entries are each counted.
//
// An empty results slice yields 0 rather than the NaN a 0/0 division would
// produce.
func recallCal[T string | int64](results []T, gts []T) float32 {
	if len(results) == 0 {
		return 0
	}

	// Index the ground truth once so each result is a single map lookup
	// instead of a scan over gts.
	truth := make(map[T]struct{}, len(gts))
	for _, gt := range gts {
		truth[gt] = struct{}{}
	}

	hit := 0
	for _, r := range results {
		if _, ok := truth[r]; ok {
			hit++
		}
	}
	return float32(hit) / float32(len(results))
}
|
|
|
|
func computeRecall(results *schemapb.SearchResultData, gts *schemapb.SearchResultData) error {
|
|
if results.GetNumQueries() != gts.GetNumQueries() {
|
|
return fmt.Errorf("num of queries is inconsistent between search results(%d) and ground truth(%d)", results.GetNumQueries(), gts.GetNumQueries())
|
|
}
|
|
|
|
switch results.GetIds().GetIdField().(type) {
|
|
case *schemapb.IDs_IntId:
|
|
switch gts.GetIds().GetIdField().(type) {
|
|
case *schemapb.IDs_IntId:
|
|
currentResultIndex := int64(0)
|
|
currentGTIndex := int64(0)
|
|
recalls := make([]float32, 0, results.GetNumQueries())
|
|
for i := 0; i < int(results.GetNumQueries()); i++ {
|
|
currentResultTopk := results.GetTopks()[i]
|
|
currentGTTopk := gts.GetTopks()[i]
|
|
recalls = append(recalls, recallCal(results.GetIds().GetIntId().GetData()[currentResultIndex:currentResultIndex+currentResultTopk],
|
|
gts.GetIds().GetIntId().GetData()[currentGTIndex:currentGTIndex+currentGTTopk]))
|
|
currentResultIndex += currentResultTopk
|
|
currentGTIndex += currentGTTopk
|
|
}
|
|
results.Recalls = recalls
|
|
return nil
|
|
case *schemapb.IDs_StrId:
|
|
return errors.New("pk type is inconsistent between search results(int64) and ground truth(string)")
|
|
default:
|
|
return errors.New("unsupported pk type")
|
|
}
|
|
|
|
case *schemapb.IDs_StrId:
|
|
switch gts.GetIds().GetIdField().(type) {
|
|
case *schemapb.IDs_StrId:
|
|
currentResultIndex := int64(0)
|
|
currentGTIndex := int64(0)
|
|
recalls := make([]float32, 0, results.GetNumQueries())
|
|
for i := 0; i < int(results.GetNumQueries()); i++ {
|
|
currentResultTopk := results.GetTopks()[i]
|
|
currentGTTopk := gts.GetTopks()[i]
|
|
recalls = append(recalls, recallCal(results.GetIds().GetStrId().GetData()[currentResultIndex:currentResultIndex+currentResultTopk],
|
|
gts.GetIds().GetStrId().GetData()[currentGTIndex:currentGTIndex+currentGTTopk]))
|
|
currentResultIndex += currentResultTopk
|
|
currentGTIndex += currentGTTopk
|
|
}
|
|
results.Recalls = recalls
|
|
return nil
|
|
case *schemapb.IDs_IntId:
|
|
return errors.New("pk type is inconsistent between search results(string) and ground truth(int64)")
|
|
default:
|
|
return errors.New("unsupported pk type")
|
|
}
|
|
default:
|
|
return errors.New("unsupported pk type")
|
|
}
|
|
}
|
|
|
|
// translateOutputFields resolves the output field names requested by the
// user against the collection schema.
//
// Wildcard support:
//
//	"*" - all fields
//
// For example, A and B are scalar fields, C and D are vector fields, duplicated fields will automatically be removed.
//
//	output_fields=["*"]    ==> [A,B,C,D]
//	output_fields=["*",A]  ==> [A,B,C,D]
//	output_fields=["*",C]  ==> [A,B,C,D]
//
// Each requested name is trimmed and then resolved in this order: "*", the
// name of a struct array field (expanded to its sub-fields), the name of a
// schema field or struct sub-field, and finally - only when the schema has
// dynamic fields enabled - a key of the dynamic field.
//
// Return values:
//  1. resultFieldNames: schema-level names to fetch; a dynamic key
//     contributes common.MetaFieldName here.
//  2. userOutputFields: names to hand back to the user; a dynamic key
//     contributes the name exactly as requested.
//  3. userDynamicFields: requested dynamic keys; left empty when "*" was
//     requested.
//  4. true if the user requested the pk field explicitly or used the wildcard.
//  5. an error when a field does not exist, its raw data may not be
//     retrieved, or a dynamic field expression cannot be parsed.
//
// If removePkField is true, the pk field will not be included in the
// first (resultFieldNames) or second (userOutputFields) return value.
//
// The three slices are built by ranging over maps, so their element order
// is not deterministic.
func translateOutputFields(outputFields []string, schema *schemaInfo, removePkField bool) ([]string, []string, []string, bool, error) {
	var primaryFieldName string
	allFieldNameMap := make(map[string]*schemapb.FieldSchema)
	resultFieldNameMap := make(map[string]bool)
	resultFieldNames := make([]string, 0)
	userOutputFieldsMap := make(map[string]bool)
	userOutputFields := make([]string, 0)
	userDynamicFieldsMap := make(map[string]bool)
	userDynamicFields := make([]string, 0)
	useAllDyncamicFields := false
	// Index the top-level fields by name and remember the primary key name.
	for _, field := range schema.Fields {
		if field.IsPrimaryKey {
			primaryFieldName = field.Name
		}
		allFieldNameMap[field.Name] = field
	}

	// User may specify a struct array field or some specific fields in the struct array field
	for _, subStruct := range schema.StructArrayFields {
		for _, field := range subStruct.Fields {
			allFieldNameMap[field.Name] = field
		}
	}

	// Struct array name -> its sub-fields, used to expand a struct name.
	structArrayNameToFields := make(map[string][]*schemapb.FieldSchema)
	for _, subStruct := range schema.StructArrayFields {
		structArrayNameToFields[subStruct.Name] = subStruct.Fields
	}

	userRequestedPkFieldExplicitly := false

	for _, outputFieldName := range outputFields {
		outputFieldName = strings.TrimSpace(outputFieldName)
		if outputFieldName == primaryFieldName {
			userRequestedPkFieldExplicitly = true
		}
		if outputFieldName == "*" {
			userRequestedPkFieldExplicitly = true
			// Wildcard: take every known field whose raw data can be
			// retrieved; fields that cannot are skipped without error.
			for fieldName, field := range allFieldNameMap {
				if schema.CanRetrieveRawFieldData(field) {
					resultFieldNameMap[fieldName] = true
					userOutputFieldsMap[fieldName] = true
				}
			}
			useAllDyncamicFields = true
		} else {
			// A struct array name expands to its retrievable sub-fields;
			// non-retrievable sub-fields are skipped without error.
			if structArrayField, ok := structArrayNameToFields[outputFieldName]; ok {
				for _, field := range structArrayField {
					if schema.CanRetrieveRawFieldData(field) {
						resultFieldNameMap[field.Name] = true
						userOutputFieldsMap[field.Name] = true
					}
				}
				continue
			}
			if field, ok := allFieldNameMap[outputFieldName]; ok {
				// Unlike the wildcard, an explicitly named field that is
				// not retrievable is an error.
				if !schema.CanRetrieveRawFieldData(field) {
					return nil, nil, nil, false, fmt.Errorf("not allowed to retrieve raw data of field %s", outputFieldName)
				}
				resultFieldNameMap[outputFieldName] = true
				userOutputFieldsMap[outputFieldName] = true
			} else {
				if schema.EnableDynamicField {
					// Unknown name with dynamic fields enabled: treat it as
					// a key of the dynamic field. The callback may replace
					// dynamicNestedPath with the parsed key.
					dynamicNestedPath := outputFieldName
					err := planparserv2.ParseIdentifier(schema.schemaHelper, outputFieldName, func(expr *planpb.Expr) error {
						columnInfo := expr.GetColumnExpr().GetInfo()
						// there must be no error here
						dynamicField, _ := schema.schemaHelper.GetDynamicField()
						// only $meta["xxx"] is allowed for now
						if dynamicField.GetFieldID() != columnInfo.GetFieldId() {
							return errors.New("not support getting subkeys of json field yet")
						}
						nestedPaths := columnInfo.GetNestedPath()
						// $meta["A"]["B"] not allowed for now
						if len(nestedPaths) != 1 {
							return errors.New("not support getting multiple level of dynamic field for now")
						}
						// $meta["dyn_field"], output field name could be:
						// 1. "dyn_field", outputFieldName == nestedPath
						// 2. `$meta["dyn_field"]` explicit form
						if nestedPaths[0] != outputFieldName {
							// use "dyn_field" as userDynamicFieldsMap when outputField = `$meta["dyn_field"]`
							dynamicNestedPath = nestedPaths[0]
						}
						return nil
					})
					if err != nil {
						// The parse error is logged; the returned error
						// carries only the field name.
						log.Info("parse output field name failed", zap.String("field name", outputFieldName), zap.Error(err))
						return nil, nil, nil, false, fmt.Errorf("parse output field name failed: %s", outputFieldName)
					}
					resultFieldNameMap[common.MetaFieldName] = true
					userOutputFieldsMap[outputFieldName] = true
					userDynamicFieldsMap[dynamicNestedPath] = true
				} else {
					return nil, nil, nil, false, fmt.Errorf("field %s not exist", outputFieldName)
				}
			}
		}
	}

	if removePkField {
		delete(resultFieldNameMap, primaryFieldName)
		delete(userOutputFieldsMap, primaryFieldName)
	}

	// Flatten the sets into slices (map iteration order is unspecified).
	for fieldName := range resultFieldNameMap {
		resultFieldNames = append(resultFieldNames, fieldName)
	}
	for fieldName := range userOutputFieldsMap {
		userOutputFields = append(userOutputFields, fieldName)
	}
	// With "*" the individual dynamic keys are not reported.
	if !useAllDyncamicFields {
		for fieldName := range userDynamicFieldsMap {
			userDynamicFields = append(userDynamicFields, fieldName)
		}
	}

	return resultFieldNames, userOutputFields, userDynamicFields, userRequestedPkFieldExplicitly, nil
}
|
|
|
|
func validateIndexName(indexName string) error {
|
|
indexName = strings.TrimSpace(indexName)
|
|
|
|
if indexName == "" {
|
|
return nil
|
|
}
|
|
invalidMsg := "Invalid index name: " + indexName + ". "
|
|
if len(indexName) > Params.ProxyCfg.MaxNameLength.GetAsInt() {
|
|
msg := invalidMsg + "The length of a index name must be less than " + Params.ProxyCfg.MaxNameLength.GetValue() + " characters."
|
|
return errors.New(msg)
|
|
}
|
|
|
|
firstChar := indexName[0]
|
|
if firstChar != '_' && !isAlpha(firstChar) {
|
|
msg := invalidMsg + "The first character of a index name must be an underscore or letter."
|
|
return errors.New(msg)
|
|
}
|
|
|
|
indexNameSize := len(indexName)
|
|
for i := 1; i < indexNameSize; i++ {
|
|
c := indexName[i]
|
|
if c != '_' && !isAlpha(c) && !isNumber(c) {
|
|
msg := invalidMsg + "Index name can only contain numbers, letters, and underscores."
|
|
return errors.New(msg)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func isCollectionLoaded(ctx context.Context, mc types.MixCoordClient, collID int64) (bool, error) {
|
|
// get all loading collections
|
|
resp, err := mc.ShowLoadCollections(ctx, &querypb.ShowCollectionsRequest{
|
|
CollectionIDs: nil,
|
|
})
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
|
|
return false, merr.Error(resp.GetStatus())
|
|
}
|
|
|
|
for _, loadedCollID := range resp.GetCollectionIDs() {
|
|
if collID == loadedCollID {
|
|
return true, nil
|
|
}
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
func isPartitionLoaded(ctx context.Context, mc types.MixCoordClient, collID int64, partID int64) (bool, error) {
|
|
// get all loading collections
|
|
resp, err := mc.ShowLoadPartitions(ctx, &querypb.ShowPartitionsRequest{
|
|
CollectionID: collID,
|
|
PartitionIDs: []int64{partID},
|
|
})
|
|
if err := merr.CheckRPCCall(resp, err); err != nil {
|
|
// qc returns error if partition not loaded
|
|
if errors.Is(err, merr.ErrPartitionNotLoaded) {
|
|
return false, nil
|
|
}
|
|
return false, err
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
// checkFieldsDataBySchema verifies that the columns carried by insertMsg
// match the given schema fields and completes the message where allowed.
//
// Checks performed:
//   - no two columns in insertMsg share a field name;
//   - AutoID is only set on the primary key;
//   - the primary key has no default value;
//   - every schema field without a column is either auto-generated, has a
//     default value, or is nullable;
//   - at most one primary key exists;
//   - the number of columns plus auto-generated fields equals len(allFields).
//
// Side effect: for each missing field that has a default value or is
// nullable, an empty column is appended to insertMsg.FieldsData with a
// ValidData slice of insertMsg.GetNumRows() false entries.
//
// inInsert participates in deciding whether an autoID primary key counts as
// auto-generated (see the condition in the loop).
func checkFieldsDataBySchema(allFields []*schemapb.FieldSchema, schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg, inInsert bool) error {
	// Shadows the package logger with one tagged with the collection name.
	log := log.With(zap.String("collection", schema.GetName()))
	primaryKeyNum := 0
	autoGenFieldNum := 0

	// Collect the names of the supplied columns, rejecting duplicates.
	dataNameSet := typeutil.NewSet[string]()
	for _, data := range insertMsg.FieldsData {
		fieldName := data.GetFieldName()
		if dataNameSet.Contain(fieldName) {
			return merr.WrapErrParameterInvalidMsg("duplicated field %s found", fieldName)
		}
		dataNameSet.Insert(fieldName)
	}

	// The error from reading the property is ignored here.
	allowInsertAutoID, _ := common.IsAllowInsertAutoID(schema.GetProperties()...)
	hasPkData := false
	needAutoGenPk := false

	for _, fieldSchema := range allFields {
		if fieldSchema.AutoID && !fieldSchema.IsPrimaryKey {
			log.Warn("not primary key field, but set autoID true", zap.String("field", fieldSchema.GetName()))
			return merr.WrapErrParameterInvalidMsg("only primary key could be with AutoID enabled")
		}

		if fieldSchema.IsPrimaryKey {
			primaryKeyNum++
			hasPkData = dataNameSet.Contain(fieldSchema.GetName())
			// The pk is generated when it is autoID and either user-supplied
			// IDs are not allowed or no pk column was supplied.
			needAutoGenPk = fieldSchema.AutoID && (!allowInsertAutoID || !hasPkData)
		}
		if fieldSchema.GetDefaultValue() != nil && fieldSchema.IsPrimaryKey {
			return merr.WrapErrParameterInvalidMsg("primary key can't be with default value")
		}
		if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && needAutoGenPk && inInsert) || typeutil.IsBM25FunctionOutputField(fieldSchema, schema) {
			// when inInsert, no need to pass when pk is autoid and SkipAutoIDCheck is false
			autoGenFieldNum++
		}
		// The schema field has no column in the message.
		if _, ok := dataNameSet[fieldSchema.GetName()]; !ok {
			if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && needAutoGenPk && inInsert) || typeutil.IsBM25FunctionOutputField(fieldSchema, schema) {
				// autoGenField
				continue
			}

			if fieldSchema.GetDefaultValue() == nil && !fieldSchema.GetNullable() {
				log.Warn("no corresponding fieldData pass in", zap.String("fieldSchema", fieldSchema.GetName()))
				return merr.WrapErrParameterInvalidMsg("fieldSchema(%s) has no corresponding fieldData pass in", fieldSchema.GetName())
			}
			// when use default_value or has set Nullable
			// it's ok that no corresponding fieldData found
			dataToAppend, err := typeutil.GenEmptyFieldData(fieldSchema)
			if err != nil {
				return err
			}
			// make() zero-fills, so every row is marked as not valid.
			dataToAppend.ValidData = make([]bool, insertMsg.GetNumRows())
			insertMsg.FieldsData = append(insertMsg.FieldsData, dataToAppend)
		}
	}

	if primaryKeyNum > 1 {
		log.Warn("more than 1 primary keys not supported",
			zap.Int64("primaryKeyNum", int64(primaryKeyNum)))
		return merr.WrapErrParameterInvalidMsg("more than 1 primary keys not supported, got %d", primaryKeyNum)
	}
	// FieldsData now includes the columns appended above.
	expectedNum := len(allFields)
	actualNum := len(insertMsg.FieldsData) + autoGenFieldNum

	// Any mismatch, in either direction, is reported with this message.
	if expectedNum != actualNum {
		log.Warn("the number of fields is not the same as needed", zap.Int("expected", expectedNum), zap.Int("actual", actualNum))
		return merr.WrapErrParameterInvalid(expectedNum, actualNum, "more fieldData has pass in")
	}

	return nil
}
|
|
|
|
// checkAndFlattenStructFieldData verifies the array length of the struct array field data in the insert message
|
|
// and then flattens the data so that data node and query node have not to handle the struct array field data.
|
|
func checkAndFlattenStructFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) error {
|
|
structSchemaMap := make(map[string]*schemapb.StructArrayFieldSchema, len(schema.GetStructArrayFields()))
|
|
for _, structField := range schema.GetStructArrayFields() {
|
|
structSchemaMap[structField.Name] = structField
|
|
}
|
|
|
|
fieldSchemaMap := make(map[string]*schemapb.FieldSchema, len(schema.GetFields()))
|
|
for _, fieldSchema := range schema.GetFields() {
|
|
fieldSchemaMap[fieldSchema.Name] = fieldSchema
|
|
}
|
|
|
|
structFieldCount := 0
|
|
flattenedFields := make([]*schemapb.FieldData, 0, len(insertMsg.GetFieldsData())+5)
|
|
|
|
for _, fieldData := range insertMsg.GetFieldsData() {
|
|
if _, ok := fieldSchemaMap[fieldData.FieldName]; ok {
|
|
flattenedFields = append(flattenedFields, fieldData)
|
|
continue
|
|
}
|
|
|
|
structName := fieldData.FieldName
|
|
structSchema, ok := structSchemaMap[structName]
|
|
if !ok {
|
|
return fmt.Errorf("fieldName %v not exist in collection schema, fieldType %v, fieldId %v", fieldData.FieldName, fieldData.Type, fieldData.FieldId)
|
|
}
|
|
|
|
structFieldCount++
|
|
structArrays, ok := fieldData.Field.(*schemapb.FieldData_StructArrays)
|
|
if !ok {
|
|
return fmt.Errorf("field convert FieldData_StructArrays fail in fieldData, fieldName: %s,"+
|
|
" collectionName:%s", structName, schema.Name)
|
|
}
|
|
|
|
if len(structArrays.StructArrays.Fields) != len(structSchema.GetFields()) {
|
|
return fmt.Errorf("length of fields of struct field mismatch length of the fields in schema, fieldName: %s,"+
|
|
" collectionName:%s, fieldData fields length:%d, schema fields length:%d",
|
|
structName, schema.Name, len(structArrays.StructArrays.Fields), len(structSchema.GetFields()))
|
|
}
|
|
|
|
// Check the array length of the struct array field data
|
|
expectedArrayLen := -1
|
|
for _, subField := range structArrays.StructArrays.Fields {
|
|
var currentArrayLen int
|
|
|
|
switch subFieldData := subField.Field.(type) {
|
|
case *schemapb.FieldData_Scalars:
|
|
if scalarArray := subFieldData.Scalars.GetArrayData(); scalarArray != nil {
|
|
currentArrayLen = len(scalarArray.Data)
|
|
} else {
|
|
return fmt.Errorf("scalar array data is nil in struct field '%s', sub-field '%s'",
|
|
structName, subField.FieldName)
|
|
}
|
|
case *schemapb.FieldData_Vectors:
|
|
if vectorArray := subFieldData.Vectors.GetVectorArray(); vectorArray != nil {
|
|
currentArrayLen = len(vectorArray.Data)
|
|
} else {
|
|
return fmt.Errorf("vector array data is nil in struct field '%s', sub-field '%s'",
|
|
structName, subField.FieldName)
|
|
}
|
|
default:
|
|
return fmt.Errorf("unexpected field data type in struct array field, fieldName: %s", structName)
|
|
}
|
|
|
|
if expectedArrayLen == -1 {
|
|
expectedArrayLen = currentArrayLen
|
|
} else if currentArrayLen != expectedArrayLen {
|
|
return fmt.Errorf("inconsistent array length in struct field '%s': expected %d, got %d for sub-field '%s'",
|
|
structName, expectedArrayLen, currentArrayLen, subField.FieldName)
|
|
}
|
|
|
|
transformedFieldName := typeutil.ConcatStructFieldName(structName, subField.FieldName)
|
|
subFieldCopy := &schemapb.FieldData{
|
|
FieldName: transformedFieldName,
|
|
FieldId: subField.FieldId,
|
|
Type: subField.Type,
|
|
Field: subField.Field,
|
|
IsDynamic: subField.IsDynamic,
|
|
}
|
|
|
|
flattenedFields = append(flattenedFields, subFieldCopy)
|
|
}
|
|
}
|
|
|
|
if len(schema.GetStructArrayFields()) != structFieldCount {
|
|
return fmt.Errorf("the number of struct array fields is not the same as needed, expected: %d, actual: %d",
|
|
len(schema.GetStructArrayFields()), structFieldCount)
|
|
}
|
|
|
|
insertMsg.FieldsData = flattenedFields
|
|
return nil
|
|
}
|
|
|
|
func checkPrimaryFieldData(allFields []*schemapb.FieldSchema, schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) (*schemapb.IDs, error) {
|
|
log := log.With(zap.String("collectionName", insertMsg.CollectionName))
|
|
rowNums := uint32(insertMsg.NRows())
|
|
// TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields
|
|
if insertMsg.NRows() <= 0 {
|
|
return nil, merr.WrapErrParameterInvalid("invalid num_rows", fmt.Sprint(rowNums), "num_rows should be greater than 0")
|
|
}
|
|
|
|
if err := checkFieldsDataBySchema(allFields, schema, insertMsg, true); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema)
|
|
if err != nil {
|
|
log.Error("get primary field schema failed", zap.Any("schema", schema), zap.Error(err))
|
|
return nil, err
|
|
}
|
|
if primaryFieldSchema.GetNullable() {
|
|
return nil, merr.WrapErrParameterInvalidMsg("primary field not support null")
|
|
}
|
|
var primaryFieldData *schemapb.FieldData
|
|
// when checkPrimaryFieldData in insert
|
|
|
|
allowInsertAutoID, _ := common.IsAllowInsertAutoID(schema.GetProperties()...)
|
|
skipAutoIDCheck := primaryFieldSchema.AutoID &&
|
|
typeutil.IsPrimaryFieldDataExist(insertMsg.GetFieldsData(), primaryFieldSchema) && (Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() || allowInsertAutoID)
|
|
|
|
if !primaryFieldSchema.AutoID || skipAutoIDCheck {
|
|
primaryFieldData, err = typeutil.GetPrimaryFieldData(insertMsg.GetFieldsData(), primaryFieldSchema)
|
|
if err != nil {
|
|
log.Info("get primary field data failed", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
} else {
|
|
// check primary key data not exist
|
|
if typeutil.IsPrimaryFieldDataExist(insertMsg.GetFieldsData(), primaryFieldSchema) {
|
|
return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("can not assign primary field data when auto id enabled and allow_insert_auto_id is false %v", primaryFieldSchema.Name))
|
|
}
|
|
// if autoID == true, currently support autoID for int64 and varchar PrimaryField
|
|
primaryFieldData, err = autoGenPrimaryFieldData(primaryFieldSchema, insertMsg.GetRowIDs())
|
|
if err != nil {
|
|
log.Info("generate primary field data failed when autoID == true", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
// if autoID == true, set the primary field data
|
|
// insertMsg.fieldsData need append primaryFieldData
|
|
insertMsg.FieldsData = append(insertMsg.FieldsData, primaryFieldData)
|
|
}
|
|
|
|
// parse primaryFieldData to result.IDs, and as returned primary keys
|
|
ids, err := parsePrimaryFieldData2IDs(primaryFieldData)
|
|
if err != nil {
|
|
log.Warn("parse primary field data to IDs failed", zap.Error(err))
|
|
return nil, err
|
|
}
|
|
|
|
return ids, nil
|
|
}
|
|
|
|
// check whether insertMsg has all fields in schema
|
|
func LackOfFieldsDataBySchema(schema *schemapb.CollectionSchema, fieldsData []*schemapb.FieldData, skipPkFieldCheck bool, skipDynamicFieldCheck bool) error {
|
|
log := log.With(zap.String("collection", schema.GetName()))
|
|
|
|
// find bm25 generated fields
|
|
bm25Fields := typeutil.NewSet[string](GetFunctionOutputFields(schema)...)
|
|
dataNameMap := make(map[string]*schemapb.FieldData)
|
|
for _, data := range fieldsData {
|
|
dataNameMap[data.GetFieldName()] = data
|
|
}
|
|
|
|
for _, fieldSchema := range schema.Fields {
|
|
if bm25Fields.Contain(fieldSchema.GetName()) {
|
|
continue
|
|
}
|
|
|
|
if fieldSchema.GetNullable() || fieldSchema.GetDefaultValue() != nil {
|
|
continue
|
|
}
|
|
|
|
if _, ok := dataNameMap[fieldSchema.GetName()]; !ok {
|
|
if (fieldSchema.IsPrimaryKey && fieldSchema.AutoID && !Params.ProxyCfg.SkipAutoIDCheck.GetAsBool() && skipPkFieldCheck) ||
|
|
typeutil.IsBM25FunctionOutputField(fieldSchema, schema) ||
|
|
(skipDynamicFieldCheck && fieldSchema.GetIsDynamic()) {
|
|
// autoGenField
|
|
continue
|
|
}
|
|
|
|
log.Info("no corresponding fieldData pass in", zap.String("fieldSchema", fieldSchema.GetName()))
|
|
return merr.WrapErrParameterInvalidMsg("fieldSchema(%s) has no corresponding fieldData pass in", fieldSchema.GetName())
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// for some varchar with analzyer
|
|
// we need check char format before insert it to message queue
|
|
// now only support utf-8
|
|
func checkInputUtf8Compatiable(allFields []*schemapb.FieldSchema, insertMsg *msgstream.InsertMsg) error {
|
|
checkeFields := lo.FilterMap(allFields, func(field *schemapb.FieldSchema, _ int) (int64, bool) {
|
|
if field.DataType == schemapb.DataType_VarChar {
|
|
return field.GetFieldID(), true
|
|
}
|
|
|
|
if field.DataType != schemapb.DataType_Text {
|
|
return 0, false
|
|
}
|
|
|
|
for _, kv := range field.GetTypeParams() {
|
|
if kv.Key == common.EnableAnalyzerKey {
|
|
return field.GetFieldID(), true
|
|
}
|
|
}
|
|
return 0, false
|
|
})
|
|
|
|
if len(checkeFields) == 0 {
|
|
return nil
|
|
}
|
|
|
|
for _, fieldData := range insertMsg.FieldsData {
|
|
if !lo.Contains(checkeFields, fieldData.GetFieldId()) {
|
|
continue
|
|
}
|
|
|
|
strData := fieldData.GetScalars().GetStringData()
|
|
for row, data := range strData.GetData() {
|
|
ok := utf8.ValidString(data)
|
|
if !ok {
|
|
log.Warn("string field data not utf-8 format", zap.String("messageVersion", strData.ProtoReflect().Descriptor().Syntax().GoString()))
|
|
return merr.WrapErrAsInputError(fmt.Errorf("input with analyzer should be utf-8 format, but row: %d not utf-8 format. data: %s", row, data))
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func checkUpsertPrimaryFieldData(allFields []*schemapb.FieldSchema, schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) (*schemapb.IDs, *schemapb.IDs, error) {
|
|
log := log.With(zap.String("collectionName", insertMsg.CollectionName))
|
|
rowNums := uint32(insertMsg.NRows())
|
|
// TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields
|
|
if insertMsg.NRows() <= 0 {
|
|
return nil, nil, merr.WrapErrParameterInvalid("invalid num_rows", fmt.Sprint(rowNums), "num_rows should be greater than 0")
|
|
}
|
|
|
|
if err := checkFieldsDataBySchema(allFields, schema, insertMsg, false); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema)
|
|
if err != nil {
|
|
log.Error("get primary field schema failed", zap.Any("schema", schema), zap.Error(err))
|
|
return nil, nil, err
|
|
}
|
|
if primaryFieldSchema.GetNullable() {
|
|
return nil, nil, merr.WrapErrParameterInvalidMsg("primary field not support null")
|
|
}
|
|
// get primaryFieldData whether autoID is true or not
|
|
var primaryFieldData *schemapb.FieldData
|
|
var newPrimaryFieldData *schemapb.FieldData
|
|
|
|
primaryFieldID := primaryFieldSchema.FieldID
|
|
primaryFieldName := primaryFieldSchema.Name
|
|
for i, field := range insertMsg.GetFieldsData() {
|
|
if field.FieldId == primaryFieldID || field.FieldName == primaryFieldName {
|
|
primaryFieldData = field
|
|
if primaryFieldSchema.AutoID {
|
|
// use the passed pk as new pk when autoID == false
|
|
// automatic generate pk as new pk wehen autoID == true
|
|
newPrimaryFieldData, err = autoGenPrimaryFieldData(primaryFieldSchema, insertMsg.GetRowIDs())
|
|
if err != nil {
|
|
log.Info("generate new primary field data failed when upsert", zap.Error(err))
|
|
return nil, nil, err
|
|
}
|
|
insertMsg.FieldsData = append(insertMsg.GetFieldsData()[:i], insertMsg.GetFieldsData()[i+1:]...)
|
|
insertMsg.FieldsData = append(insertMsg.FieldsData, newPrimaryFieldData)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
// must assign primary field data when upsert
|
|
if primaryFieldData == nil {
|
|
return nil, nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("must assign pk when upsert, primary field: %v", primaryFieldName))
|
|
}
|
|
|
|
// parse primaryFieldData to result.IDs, and as returned primary keys
|
|
ids, err := parsePrimaryFieldData2IDs(primaryFieldData)
|
|
if err != nil {
|
|
log.Warn("parse primary field data to IDs failed", zap.Error(err))
|
|
return nil, nil, err
|
|
}
|
|
if !primaryFieldSchema.GetAutoID() {
|
|
return ids, ids, nil
|
|
}
|
|
newIDs, err := parsePrimaryFieldData2IDs(newPrimaryFieldData)
|
|
if err != nil {
|
|
log.Warn("parse primary field data to IDs failed", zap.Error(err))
|
|
return nil, nil, err
|
|
}
|
|
return newIDs, ids, nil
|
|
}
|
|
|
|
func getPartitionKeyFieldData(fieldSchema *schemapb.FieldSchema, insertMsg *msgstream.InsertMsg) (*schemapb.FieldData, error) {
|
|
if len(insertMsg.GetPartitionName()) > 0 && !Params.ProxyCfg.SkipPartitionKeyCheck.GetAsBool() {
|
|
return nil, errors.New("not support manually specifying the partition names if partition key mode is used")
|
|
}
|
|
|
|
for _, fieldData := range insertMsg.GetFieldsData() {
|
|
if fieldData.GetFieldId() == fieldSchema.GetFieldID() {
|
|
return fieldData, nil
|
|
}
|
|
}
|
|
|
|
return nil, errors.New("partition key not specify when insert")
|
|
}
|
|
|
|
func getCollectionProgress(
|
|
ctx context.Context,
|
|
queryCoord types.QueryCoordClient,
|
|
msgBase *commonpb.MsgBase,
|
|
collectionID int64,
|
|
) (loadProgress int64, refreshProgress int64, err error) {
|
|
resp, err := queryCoord.ShowLoadCollections(ctx, &querypb.ShowCollectionsRequest{
|
|
Base: commonpbutil.UpdateMsgBase(
|
|
msgBase,
|
|
commonpbutil.WithMsgType(commonpb.MsgType_ShowCollections),
|
|
),
|
|
CollectionIDs: []int64{collectionID},
|
|
})
|
|
if err != nil {
|
|
log.Ctx(ctx).Warn("fail to show collections",
|
|
zap.Int64("collectionID", collectionID),
|
|
zap.Error(err),
|
|
)
|
|
return
|
|
}
|
|
|
|
err = merr.Error(resp.GetStatus())
|
|
if err != nil {
|
|
log.Ctx(ctx).Warn("fail to show collections",
|
|
zap.Int64("collectionID", collectionID),
|
|
zap.Error(err))
|
|
return
|
|
}
|
|
|
|
loadProgress = resp.GetInMemoryPercentages()[0]
|
|
if len(resp.GetRefreshProgress()) > 0 { // Compatibility for new Proxy with old QueryCoord
|
|
refreshProgress = resp.GetRefreshProgress()[0]
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func getPartitionProgress(
|
|
ctx context.Context,
|
|
queryCoord types.QueryCoordClient,
|
|
msgBase *commonpb.MsgBase,
|
|
partitionNames []string,
|
|
collectionName string,
|
|
collectionID int64,
|
|
dbName string,
|
|
) (loadProgress int64, refreshProgress int64, err error) {
|
|
IDs2Names := make(map[int64]string)
|
|
partitionIDs := make([]int64, 0)
|
|
for _, partitionName := range partitionNames {
|
|
var partitionID int64
|
|
partitionID, err = globalMetaCache.GetPartitionID(ctx, dbName, collectionName, partitionName)
|
|
if err != nil {
|
|
return
|
|
}
|
|
IDs2Names[partitionID] = partitionName
|
|
partitionIDs = append(partitionIDs, partitionID)
|
|
}
|
|
|
|
var resp *querypb.ShowPartitionsResponse
|
|
resp, err = queryCoord.ShowLoadPartitions(ctx, &querypb.ShowPartitionsRequest{
|
|
Base: commonpbutil.UpdateMsgBase(
|
|
msgBase,
|
|
commonpbutil.WithMsgType(commonpb.MsgType_ShowPartitions),
|
|
),
|
|
CollectionID: collectionID,
|
|
PartitionIDs: partitionIDs,
|
|
})
|
|
if err != nil {
|
|
log.Ctx(ctx).Warn("fail to show partitions", zap.Int64("collection_id", collectionID),
|
|
zap.String("collection_name", collectionName),
|
|
zap.Strings("partition_names", partitionNames),
|
|
zap.Error(err))
|
|
return
|
|
}
|
|
|
|
err = merr.Error(resp.GetStatus())
|
|
if err != nil {
|
|
err = merr.Error(resp.GetStatus())
|
|
log.Ctx(ctx).Warn("fail to show partitions",
|
|
zap.String("collectionName", collectionName),
|
|
zap.Strings("partitionNames", partitionNames),
|
|
zap.Error(err))
|
|
return
|
|
}
|
|
|
|
for _, p := range resp.InMemoryPercentages {
|
|
loadProgress += p
|
|
}
|
|
loadProgress /= int64(len(partitionIDs))
|
|
|
|
if len(resp.GetRefreshProgress()) > 0 { // Compatibility for new Proxy with old QueryCoord
|
|
refreshProgress = resp.GetRefreshProgress()[0]
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func isPartitionKeyMode(ctx context.Context, dbName string, colName string) (bool, error) {
|
|
colSchema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, colName)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
for _, fieldSchema := range colSchema.GetFields() {
|
|
if fieldSchema.IsPartitionKey {
|
|
return true, nil
|
|
}
|
|
}
|
|
|
|
return false, nil
|
|
}
|
|
|
|
func hasPartitionKeyModeField(schema *schemapb.CollectionSchema) bool {
|
|
for _, fieldSchema := range schema.GetFields() {
|
|
if fieldSchema.IsPartitionKey {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// getDefaultPartitionsInPartitionKeyMode only used in partition key mode
|
|
func getDefaultPartitionsInPartitionKeyMode(ctx context.Context, dbName string, collectionName string) ([]string, error) {
|
|
partitions, err := globalMetaCache.GetPartitions(ctx, dbName, collectionName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Make sure the order of the partition names got every time is the same
|
|
partitionNames, _, err := typeutil.RearrangePartitionsForPartitionKey(partitions)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return partitionNames, nil
|
|
}
|
|
|
|
func assignChannelsByPK(pks *schemapb.IDs, channelNames []string, insertMsg *msgstream.InsertMsg) map[string][]int {
|
|
insertMsg.HashValues = typeutil.HashPK2Channels(pks, channelNames)
|
|
|
|
// groupedHashKeys represents the dmChannel index
|
|
channel2RowOffsets := make(map[string][]int) // channelName to count
|
|
// assert len(it.hashValues) < maxInt
|
|
for offset, channelID := range insertMsg.HashValues {
|
|
channelName := channelNames[channelID]
|
|
if _, ok := channel2RowOffsets[channelName]; !ok {
|
|
channel2RowOffsets[channelName] = []int{}
|
|
}
|
|
channel2RowOffsets[channelName] = append(channel2RowOffsets[channelName], offset)
|
|
}
|
|
|
|
return channel2RowOffsets
|
|
}
|
|
|
|
func assignPartitionKeys(ctx context.Context, dbName string, collName string, keys []*planpb.GenericValue) ([]string, error) {
|
|
partitionNames, err := globalMetaCache.GetPartitionsIndex(ctx, dbName, collName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
schema, err := globalMetaCache.GetCollectionSchema(ctx, dbName, collName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
partitionKeyFieldSchema, err := typeutil.GetPartitionKeyFieldSchema(schema.CollectionSchema)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
hashedPartitionNames, err := typeutil2.HashKey2Partitions(partitionKeyFieldSchema, keys, partitionNames)
|
|
return hashedPartitionNames, err
|
|
}
|
|
|
|
func ErrWithLog(logger *log.MLogger, msg string, err error) error {
|
|
wrapErr := errors.Wrap(err, msg)
|
|
if logger != nil {
|
|
logger.Warn(msg, zap.Error(err))
|
|
return wrapErr
|
|
}
|
|
log.Warn(msg, zap.Error(err))
|
|
return wrapErr
|
|
}
|
|
|
|
func verifyDynamicFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) error {
|
|
for _, field := range insertMsg.FieldsData {
|
|
if field.GetFieldName() == common.MetaFieldName {
|
|
if !schema.EnableDynamicField {
|
|
return fmt.Errorf("without dynamic schema enabled, the field name cannot be set to %s", common.MetaFieldName)
|
|
}
|
|
for _, rowData := range field.GetScalars().GetJsonData().GetData() {
|
|
jsonData := make(map[string]interface{})
|
|
if err := json.Unmarshal(rowData, &jsonData); err != nil {
|
|
log.Info("insert invalid dynamic data, milvus only support json map",
|
|
zap.ByteString("data", rowData),
|
|
zap.Error(err),
|
|
)
|
|
return merr.WrapErrIoFailedReason(err.Error())
|
|
}
|
|
if _, ok := jsonData[common.MetaFieldName]; ok {
|
|
return fmt.Errorf("cannot set json key to: %s", common.MetaFieldName)
|
|
}
|
|
for _, f := range schema.GetFields() {
|
|
if _, ok := jsonData[f.GetName()]; ok {
|
|
log.Info("dynamic field name include the static field name", zap.String("fieldName", f.GetName()))
|
|
return fmt.Errorf("dynamic field name cannot include the static field name: %s", f.GetName())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func checkDynamicFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) error {
|
|
for _, data := range insertMsg.FieldsData {
|
|
if data.IsDynamic {
|
|
data.FieldName = common.MetaFieldName
|
|
return verifyDynamicFieldData(schema, insertMsg)
|
|
}
|
|
}
|
|
defaultData := make([][]byte, insertMsg.NRows())
|
|
for i := range defaultData {
|
|
defaultData[i] = []byte("{}")
|
|
}
|
|
dynamicData := autoGenDynamicFieldData(defaultData)
|
|
insertMsg.FieldsData = append(insertMsg.FieldsData, dynamicData)
|
|
return nil
|
|
}
|
|
|
|
func addNamespaceData(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) error {
|
|
err := common.CheckNamespace(schema, insertMsg.InsertRequest.Namespace)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
namespaceEnabeld, _, err := common.ParseNamespaceProp(schema.Properties...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !namespaceEnabeld {
|
|
return nil
|
|
}
|
|
|
|
// check namespace field exists
|
|
namespaceField := typeutil.GetFieldByName(schema, common.NamespaceFieldName)
|
|
if namespaceField == nil {
|
|
return fmt.Errorf("namespace field not found")
|
|
}
|
|
|
|
// check namespace field data is already set
|
|
for _, fieldData := range insertMsg.FieldsData {
|
|
if fieldData.FieldId == namespaceField.FieldID {
|
|
return fmt.Errorf("namespace field data is already set by users")
|
|
}
|
|
}
|
|
|
|
// set namespace field data
|
|
namespaceData := make([]string, insertMsg.NRows())
|
|
namespace := *insertMsg.InsertRequest.Namespace
|
|
for i := range namespaceData {
|
|
namespaceData[i] = namespace
|
|
}
|
|
insertMsg.FieldsData = append(insertMsg.FieldsData, &schemapb.FieldData{
|
|
FieldName: namespaceField.Name,
|
|
FieldId: namespaceField.FieldID,
|
|
Type: namespaceField.DataType,
|
|
Field: &schemapb.FieldData_Scalars{
|
|
Scalars: &schemapb.ScalarField{
|
|
Data: &schemapb.ScalarField_StringData{
|
|
StringData: &schemapb.StringArray{
|
|
Data: namespaceData,
|
|
},
|
|
},
|
|
},
|
|
},
|
|
})
|
|
return nil
|
|
}
|
|
|
|
func GetCachedCollectionSchema(ctx context.Context, dbName string, colName string) (*schemaInfo, error) {
|
|
if globalMetaCache != nil {
|
|
return globalMetaCache.GetCollectionSchema(ctx, dbName, colName)
|
|
}
|
|
return nil, merr.WrapErrServiceNotReady(paramtable.GetRole(), paramtable.GetNodeID(), "initialization")
|
|
}
|
|
|
|
func CheckDatabase(ctx context.Context, dbName string) bool {
|
|
if globalMetaCache != nil {
|
|
return globalMetaCache.HasDatabase(ctx, dbName)
|
|
}
|
|
return false
|
|
}
|
|
|
|
func SetReportValue(status *commonpb.Status, value int) {
|
|
if value <= 0 {
|
|
return
|
|
}
|
|
if !merr.Ok(status) {
|
|
return
|
|
}
|
|
if status.ExtraInfo == nil {
|
|
status.ExtraInfo = make(map[string]string)
|
|
}
|
|
status.ExtraInfo["report_value"] = strconv.Itoa(value)
|
|
}
|
|
|
|
func SetStorageCost(status *commonpb.Status, storageCost segcore.StorageCost) {
|
|
if !Params.QueryNodeCfg.StorageUsageTrackingEnabled.GetAsBool() {
|
|
return
|
|
}
|
|
if storageCost.ScannedTotalBytes <= 0 {
|
|
return
|
|
}
|
|
if !merr.Ok(status) {
|
|
return
|
|
}
|
|
if status.ExtraInfo == nil {
|
|
status.ExtraInfo = make(map[string]string)
|
|
// set report_value to 0 for compatibility, when extra info is not nil, there are always the default report_value
|
|
// see https://github.com/milvus-io/pymilvus/pull/2999, pymilvus didn't check the report_value is set and use the value
|
|
status.ExtraInfo["report_value"] = strconv.Itoa(0)
|
|
}
|
|
|
|
status.ExtraInfo["scanned_remote_bytes"] = strconv.FormatInt(storageCost.ScannedRemoteBytes, 10)
|
|
status.ExtraInfo["scanned_total_bytes"] = strconv.FormatInt(storageCost.ScannedTotalBytes, 10)
|
|
cacheHitRatio := float64(storageCost.ScannedTotalBytes-storageCost.ScannedRemoteBytes) / float64(storageCost.ScannedTotalBytes)
|
|
status.ExtraInfo["cache_hit_ratio"] = strconv.FormatFloat(cacheHitRatio, 'f', -1, 64)
|
|
}
|
|
|
|
func GetCostValue(status *commonpb.Status) int {
|
|
if status == nil || status.ExtraInfo == nil {
|
|
return 0
|
|
}
|
|
value, err := strconv.Atoi(status.ExtraInfo["report_value"])
|
|
if err != nil {
|
|
return 0
|
|
}
|
|
return value
|
|
}
|
|
|
|
// final return value means value is valid or not
|
|
func GetStorageCost(status *commonpb.Status) (int64, int64, float64, bool) {
|
|
if status == nil || status.ExtraInfo == nil {
|
|
return 0, 0, 0, false
|
|
}
|
|
var scannedRemoteBytes int64
|
|
var scannedTotalBytes int64
|
|
var cacheHitRatio float64
|
|
var err error
|
|
if value, ok := status.ExtraInfo["scanned_remote_bytes"]; ok {
|
|
scannedRemoteBytes, err = strconv.ParseInt(value, 10, 64)
|
|
if err != nil {
|
|
log.Warn("scanned_remote_bytes is not a valid int64", zap.String("value", value), zap.Error(err))
|
|
return 0, 0, 0, false
|
|
}
|
|
} else {
|
|
return 0, 0, 0, false
|
|
}
|
|
if value, ok := status.ExtraInfo["scanned_total_bytes"]; ok {
|
|
scannedTotalBytes, err = strconv.ParseInt(value, 10, 64)
|
|
if err != nil {
|
|
log.Warn("scanned_total_bytes is not a valid int64", zap.String("value", value), zap.Error(err))
|
|
return 0, 0, 0, false
|
|
}
|
|
} else {
|
|
return 0, 0, 0, false
|
|
}
|
|
if value, ok := status.ExtraInfo["cache_hit_ratio"]; ok {
|
|
cacheHitRatio, err = strconv.ParseFloat(value, 64)
|
|
if err != nil {
|
|
log.Warn("cache_hit_ratio is not a valid float64", zap.String("value", value), zap.Error(err))
|
|
return 0, 0, 0, false
|
|
}
|
|
} else {
|
|
return 0, 0, 0, false
|
|
}
|
|
return scannedRemoteBytes, scannedTotalBytes, cacheHitRatio, true
|
|
}
|
|
|
|
// GetRequestInfo returns collection name and rateType of request and return tokens needed.
|
|
func GetRequestInfo(ctx context.Context, req proto.Message) (int64, map[int64][]int64, internalpb.RateType, int, error) {
|
|
switch r := req.(type) {
|
|
case *milvuspb.InsertRequest:
|
|
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
|
|
return dbID, collToPartIDs, internalpb.RateType_DMLInsert, proto.Size(r), err
|
|
case *milvuspb.UpsertRequest:
|
|
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
|
|
return dbID, collToPartIDs, internalpb.RateType_DMLInsert, proto.Size(r), err
|
|
case *milvuspb.DeleteRequest:
|
|
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
|
|
return dbID, collToPartIDs, internalpb.RateType_DMLDelete, proto.Size(r), err
|
|
case *milvuspb.ImportRequest:
|
|
dbID, collToPartIDs, err := getCollectionAndPartitionID(ctx, req.(reqPartName))
|
|
return dbID, collToPartIDs, internalpb.RateType_DMLBulkLoad, proto.Size(r), err
|
|
case *milvuspb.SearchRequest:
|
|
dbID, collToPartIDs, err := getCollectionAndPartitionIDs(ctx, req.(reqPartNames))
|
|
return dbID, collToPartIDs, internalpb.RateType_DQLSearch, int(r.GetNq()), err
|
|
case *milvuspb.QueryRequest:
|
|
dbID, collToPartIDs, err := getCollectionAndPartitionIDs(ctx, req.(reqPartNames))
|
|
return dbID, collToPartIDs, internalpb.RateType_DQLQuery, 1, err // think of the query request's nq as 1
|
|
case *milvuspb.CreateCollectionRequest:
|
|
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
|
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
|
|
case *milvuspb.DropCollectionRequest:
|
|
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
|
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
|
|
case *milvuspb.LoadCollectionRequest:
|
|
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
|
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
|
|
case *milvuspb.ReleaseCollectionRequest:
|
|
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
|
return dbID, collToPartIDs, internalpb.RateType_DDLCollection, 1, nil
|
|
case *milvuspb.CreatePartitionRequest:
|
|
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
|
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
|
|
case *milvuspb.DropPartitionRequest:
|
|
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
|
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
|
|
case *milvuspb.LoadPartitionsRequest:
|
|
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
|
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
|
|
case *milvuspb.ReleasePartitionsRequest:
|
|
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
|
return dbID, collToPartIDs, internalpb.RateType_DDLPartition, 1, nil
|
|
case *milvuspb.CreateIndexRequest:
|
|
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
|
return dbID, collToPartIDs, internalpb.RateType_DDLIndex, 1, nil
|
|
case *milvuspb.DropIndexRequest:
|
|
dbID, collToPartIDs := getCollectionID(req.(reqCollName))
|
|
return dbID, collToPartIDs, internalpb.RateType_DDLIndex, 1, nil
|
|
case *milvuspb.FlushRequest:
|
|
db, err := globalMetaCache.GetDatabaseInfo(ctx, r.GetDbName())
|
|
if err != nil {
|
|
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
|
|
}
|
|
|
|
collToPartIDs := make(map[int64][]int64, 0)
|
|
for _, collectionName := range r.GetCollectionNames() {
|
|
collectionID, err := globalMetaCache.GetCollectionID(ctx, r.GetDbName(), collectionName)
|
|
if err != nil {
|
|
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
|
|
}
|
|
collToPartIDs[collectionID] = []int64{}
|
|
}
|
|
return db.dbID, collToPartIDs, internalpb.RateType_DDLFlush, 1, nil
|
|
case *milvuspb.ManualCompactionRequest:
|
|
dbName := GetCurDBNameFromContextOrDefault(ctx)
|
|
dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, dbName)
|
|
if err != nil {
|
|
return util.InvalidDBID, map[int64][]int64{}, 0, 0, err
|
|
}
|
|
return dbInfo.dbID, map[int64][]int64{
|
|
r.GetCollectionID(): {},
|
|
}, internalpb.RateType_DDLCompaction, 1, nil
|
|
case *milvuspb.CreateDatabaseRequest:
|
|
log.Info("rate limiter CreateDatabaseRequest")
|
|
return util.InvalidDBID, map[int64][]int64{}, internalpb.RateType_DDLDB, 1, nil
|
|
case *milvuspb.DropDatabaseRequest:
|
|
log.Info("rate limiter DropDatabaseRequest")
|
|
return util.InvalidDBID, map[int64][]int64{}, internalpb.RateType_DDLDB, 1, nil
|
|
case *milvuspb.AlterDatabaseRequest:
|
|
return util.InvalidDBID, map[int64][]int64{}, internalpb.RateType_DDLDB, 1, nil
|
|
default: // TODO: support more request
|
|
if req == nil {
|
|
return util.InvalidDBID, map[int64][]int64{}, 0, 0, errors.New("null request")
|
|
}
|
|
log.RatedWarn(60, "not supported request type for rate limiter", zap.String("type", reflect.TypeOf(req).String()))
|
|
return util.InvalidDBID, map[int64][]int64{}, 0, 0, nil
|
|
}
|
|
}
|
|
|
|
// GetFailedResponse returns failed response.
|
|
func GetFailedResponse(req any, err error) any {
|
|
switch req.(type) {
|
|
case *milvuspb.InsertRequest, *milvuspb.DeleteRequest, *milvuspb.UpsertRequest:
|
|
return failedMutationResult(err)
|
|
case *milvuspb.ImportRequest:
|
|
return &milvuspb.ImportResponse{
|
|
Status: merr.Status(err),
|
|
}
|
|
case *milvuspb.SearchRequest:
|
|
return &milvuspb.SearchResults{
|
|
Status: merr.Status(err),
|
|
}
|
|
case *milvuspb.QueryRequest:
|
|
return &milvuspb.QueryResults{
|
|
Status: merr.Status(err),
|
|
}
|
|
case *milvuspb.CreateCollectionRequest, *milvuspb.DropCollectionRequest,
|
|
*milvuspb.LoadCollectionRequest, *milvuspb.ReleaseCollectionRequest,
|
|
*milvuspb.CreatePartitionRequest, *milvuspb.DropPartitionRequest,
|
|
*milvuspb.LoadPartitionsRequest, *milvuspb.ReleasePartitionsRequest,
|
|
*milvuspb.CreateIndexRequest, *milvuspb.DropIndexRequest,
|
|
*milvuspb.CreateDatabaseRequest, *milvuspb.DropDatabaseRequest,
|
|
*milvuspb.AlterDatabaseRequest:
|
|
return merr.Status(err)
|
|
case *milvuspb.FlushRequest:
|
|
return &milvuspb.FlushResponse{
|
|
Status: merr.Status(err),
|
|
}
|
|
case *milvuspb.ManualCompactionRequest:
|
|
return &milvuspb.ManualCompactionResponse{
|
|
Status: merr.Status(err),
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func GetReplicateID(ctx context.Context, database, collectionName string) (string, error) {
|
|
if globalMetaCache == nil {
|
|
return "", merr.WrapErrServiceUnavailable("internal: Milvus Proxy is not ready yet. please wait")
|
|
}
|
|
colInfo, err := globalMetaCache.GetCollectionInfo(ctx, database, collectionName, 0)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if colInfo.replicateID != "" {
|
|
return colInfo.replicateID, nil
|
|
}
|
|
dbInfo, err := globalMetaCache.GetDatabaseInfo(ctx, database)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
replicateID, _ := common.GetReplicateID(dbInfo.properties)
|
|
return replicateID, nil
|
|
}
|
|
|
|
func GetFunctionOutputFields(collSchema *schemapb.CollectionSchema) []string {
|
|
fields := make([]string, 0)
|
|
for _, fSchema := range collSchema.Functions {
|
|
fields = append(fields, fSchema.OutputFieldNames...)
|
|
}
|
|
return fields
|
|
}
|
|
|
|
func GetBM25FunctionOutputFields(collSchema *schemapb.CollectionSchema) []string {
|
|
fields := make([]string, 0)
|
|
for _, fSchema := range collSchema.Functions {
|
|
if fSchema.Type == schemapb.FunctionType_BM25 {
|
|
fields = append(fields, fSchema.OutputFieldNames...)
|
|
}
|
|
}
|
|
return fields
|
|
}
|
|
|
|
func getCollectionTTL(pairs []*commonpb.KeyValuePair) uint64 {
|
|
properties := make(map[string]string)
|
|
for _, pair := range pairs {
|
|
properties[pair.Key] = pair.Value
|
|
}
|
|
|
|
v, ok := properties[common.CollectionTTLConfigKey]
|
|
if ok {
|
|
ttl, err := strconv.Atoi(v)
|
|
if err != nil {
|
|
return 0
|
|
}
|
|
return uint64(time.Duration(ttl) * time.Second)
|
|
}
|
|
|
|
return 0
|
|
}
|
|
|
|
// reconstructStructFieldDataCommon reconstructs struct fields from flattened sub-fields
|
|
// It works with both QueryResults and SearchResults by operating on the common data structures
|
|
func reconstructStructFieldDataCommon(
|
|
fieldsData []*schemapb.FieldData,
|
|
outputFields []string,
|
|
schema *schemapb.CollectionSchema,
|
|
) ([]*schemapb.FieldData, []string) {
|
|
if len(outputFields) == 1 && outputFields[0] == "count(*)" {
|
|
return fieldsData, outputFields
|
|
}
|
|
|
|
if len(schema.StructArrayFields) == 0 {
|
|
return fieldsData, outputFields
|
|
}
|
|
|
|
regularFieldIDs := make(map[int64]interface{})
|
|
subFieldToStructMap := make(map[int64]int64)
|
|
groupedStructFields := make(map[int64][]*schemapb.FieldData)
|
|
structFieldNames := make(map[int64]string)
|
|
reconstructedOutputFields := make([]string, 0, len(fieldsData))
|
|
|
|
// record all regular field IDs
|
|
for _, field := range schema.Fields {
|
|
regularFieldIDs[field.GetFieldID()] = nil
|
|
}
|
|
|
|
// build the mapping from sub-field ID to struct field ID
|
|
for _, structField := range schema.StructArrayFields {
|
|
for _, subField := range structField.GetFields() {
|
|
subFieldToStructMap[subField.GetFieldID()] = structField.GetFieldID()
|
|
}
|
|
structFieldNames[structField.GetFieldID()] = structField.GetName()
|
|
}
|
|
|
|
newFieldsData := make([]*schemapb.FieldData, 0, len(fieldsData))
|
|
for _, field := range fieldsData {
|
|
fieldID := field.GetFieldId()
|
|
if _, ok := regularFieldIDs[fieldID]; ok {
|
|
newFieldsData = append(newFieldsData, field)
|
|
reconstructedOutputFields = append(reconstructedOutputFields, field.GetFieldName())
|
|
} else {
|
|
structFieldID := subFieldToStructMap[fieldID]
|
|
groupedStructFields[structFieldID] = append(groupedStructFields[structFieldID], field)
|
|
}
|
|
}
|
|
|
|
for structFieldID, fields := range groupedStructFields {
|
|
// Create deep copies of fields to avoid modifying original data
|
|
// and restore original field names for user-facing response
|
|
copiedFields := make([]*schemapb.FieldData, len(fields))
|
|
for i, field := range fields {
|
|
copiedFields[i] = proto.Clone(field).(*schemapb.FieldData)
|
|
// Extract original field name from structName[fieldName] format
|
|
originalName, err := extractOriginalFieldName(copiedFields[i].FieldName)
|
|
if err != nil {
|
|
// This should not happen in normal operation - indicates a bug
|
|
log.Error("failed to extract original field name from struct field",
|
|
zap.String("fieldName", copiedFields[i].FieldName),
|
|
zap.Error(err))
|
|
// Keep the transformed name to avoid data corruption
|
|
} else {
|
|
copiedFields[i].FieldName = originalName
|
|
}
|
|
}
|
|
|
|
fieldData := &schemapb.FieldData{
|
|
FieldName: structFieldNames[structFieldID],
|
|
FieldId: structFieldID,
|
|
Type: schemapb.DataType_ArrayOfStruct,
|
|
Field: &schemapb.FieldData_StructArrays{StructArrays: &schemapb.StructArrayField{Fields: copiedFields}},
|
|
}
|
|
newFieldsData = append(newFieldsData, fieldData)
|
|
reconstructedOutputFields = append(reconstructedOutputFields, structFieldNames[structFieldID])
|
|
}
|
|
|
|
return newFieldsData, reconstructedOutputFields
|
|
}
|
|
|
|
// Wrapper for QueryResults
|
|
func reconstructStructFieldDataForQuery(results *milvuspb.QueryResults, schema *schemapb.CollectionSchema) {
|
|
fieldsData, outputFields := reconstructStructFieldDataCommon(
|
|
results.FieldsData,
|
|
results.OutputFields,
|
|
schema,
|
|
)
|
|
results.FieldsData = fieldsData
|
|
results.OutputFields = outputFields
|
|
}
|
|
|
|
// New wrapper for SearchResults
|
|
func reconstructStructFieldDataForSearch(results *milvuspb.SearchResults, schema *schemapb.CollectionSchema) {
|
|
if results.Results == nil {
|
|
return
|
|
}
|
|
fieldsData, outputFields := reconstructStructFieldDataCommon(
|
|
results.Results.FieldsData,
|
|
results.Results.OutputFields,
|
|
schema,
|
|
)
|
|
results.Results.FieldsData = fieldsData
|
|
results.Results.OutputFields = outputFields
|
|
}
|
|
|
|
func hasTimestamptzField(schema *schemapb.CollectionSchema) bool {
|
|
for _, field := range schema.Fields {
|
|
if field.GetDataType() == schemapb.DataType_Timestamptz {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func getDefaultTimezoneVal(props ...*commonpb.KeyValuePair) (bool, string) {
|
|
for _, p := range props {
|
|
// used in collection or database
|
|
if p.GetKey() == common.DatabaseDefaultTimezone || p.GetKey() == common.CollectionDefaultTimezone {
|
|
return true, p.Value
|
|
}
|
|
}
|
|
return false, ""
|
|
}
|
|
|
|
func checkTimezone(props ...*commonpb.KeyValuePair) error {
|
|
hasTImezone, timezoneStr := getDefaultTimezoneVal(props...)
|
|
if hasTImezone {
|
|
_, err := time.LoadLocation(timezoneStr)
|
|
if err != nil {
|
|
return merr.WrapErrParameterInvalidMsg("invalid timezone, should be a IANA timezone name: %s", err.Error())
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func getColTimezone(colInfo *collectionInfo) (bool, string) {
|
|
return getDefaultTimezoneVal(colInfo.properties...)
|
|
}
|
|
|
|
func getDbTimezone(dbInfo *databaseInfo) (bool, string) {
|
|
return getDefaultTimezoneVal(dbInfo.properties...)
|
|
}
|
|
|
|
func timestamptzIsoStr2Utc(columns []*schemapb.FieldData, colTimezone string) error {
|
|
naiveLayouts := []string{
|
|
"2006-01-02T15:04:05.999999999",
|
|
"2006-01-02T15:04:05",
|
|
"2006-01-02 15:04:05.999999999",
|
|
"2006-01-02 15:04:05",
|
|
}
|
|
for _, fieldData := range columns {
|
|
if fieldData.GetType() != schemapb.DataType_Timestamptz {
|
|
continue
|
|
}
|
|
|
|
scalarField := fieldData.GetScalars()
|
|
if scalarField == nil || scalarField.GetStringData() == nil {
|
|
log.Warn("field data is not string data", zap.String("fieldName", fieldData.GetFieldName()))
|
|
return merr.WrapErrParameterInvalidMsg("field data is not string data")
|
|
}
|
|
|
|
stringData := scalarField.GetStringData().GetData()
|
|
utcTimestamps := make([]int64, len(stringData))
|
|
|
|
for i, isoStr := range stringData {
|
|
var t time.Time
|
|
var err error
|
|
// parse directly
|
|
t, err = time.Parse(time.RFC3339Nano, isoStr)
|
|
if err == nil {
|
|
utcTimestamps[i] = t.UnixMicro()
|
|
continue
|
|
}
|
|
// no timezone, try to find timezone in collecion -> database level
|
|
defaultTZ := "UTC"
|
|
if colTimezone != "" {
|
|
defaultTZ = colTimezone
|
|
}
|
|
|
|
location, err := time.LoadLocation(defaultTZ)
|
|
if err != nil {
|
|
log.Error("invalid timezone", zap.String("timezone", defaultTZ), zap.Error(err))
|
|
return merr.WrapErrParameterInvalidMsg("got invalid default timezone: %s", defaultTZ)
|
|
}
|
|
var parsed bool
|
|
for _, layout := range naiveLayouts {
|
|
t, err = time.ParseInLocation(layout, isoStr, location)
|
|
if err == nil {
|
|
parsed = true
|
|
break
|
|
}
|
|
}
|
|
if !parsed {
|
|
log.Warn("Can not parse timestamptz string", zap.String("timestamp_string", isoStr))
|
|
return merr.WrapErrParameterInvalidMsg("got invalid timestamptz string: %s", isoStr)
|
|
}
|
|
utcTimestamps[i] = t.UnixMicro()
|
|
}
|
|
// Replace data in place
|
|
fieldData.GetScalars().Data = &schemapb.ScalarField_TimestamptzData{
|
|
TimestamptzData: &schemapb.TimestamptzArray{
|
|
Data: utcTimestamps,
|
|
},
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func timestamptzUTC2IsoStr(results []*schemapb.FieldData, userDefineTimezone string, colTimezone string) error {
|
|
// Determine the target timezone based on priority: collection -> database -> UTC.
|
|
defaultTZ := "UTC"
|
|
if userDefineTimezone != "" {
|
|
defaultTZ = userDefineTimezone
|
|
} else if colTimezone != "" {
|
|
defaultTZ = colTimezone
|
|
}
|
|
|
|
location, err := time.LoadLocation(defaultTZ)
|
|
if err != nil {
|
|
log.Error("invalid timezone", zap.String("timezone", defaultTZ), zap.Error(err))
|
|
return merr.WrapErrParameterInvalidMsg("got invalid default timezone: %s", defaultTZ)
|
|
}
|
|
|
|
for _, fieldData := range results {
|
|
if fieldData.GetType() != schemapb.DataType_Timestamptz {
|
|
continue
|
|
}
|
|
scalarField := fieldData.GetScalars()
|
|
if scalarField == nil || scalarField.GetTimestamptzData() == nil {
|
|
if longData := scalarField.GetLongData(); longData != nil && len(longData.GetData()) > 0 {
|
|
log.Warn("field data is not Timestamptz data", zap.String("fieldName", fieldData.GetFieldName()))
|
|
return merr.WrapErrParameterInvalidMsg("field data for '%s' is not Timestamptz data", fieldData.GetFieldName())
|
|
}
|
|
}
|
|
|
|
utcTimestamps := scalarField.GetTimestamptzData().GetData()
|
|
isoStrings := make([]string, len(utcTimestamps))
|
|
|
|
for i, ts := range utcTimestamps {
|
|
t := time.UnixMicro(ts).UTC()
|
|
localTime := t.In(location)
|
|
isoStrings[i] = localTime.Format(time.RFC3339Nano)
|
|
}
|
|
|
|
// Replace the TimestamptzData with the new StringData in place.
|
|
fieldData.GetScalars().Data = &schemapb.ScalarField_StringData{
|
|
StringData: &schemapb.StringArray{
|
|
Data: isoStrings,
|
|
},
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// extractFields is a helper function to extract specific integer fields from a time.Time object.
|
|
// Supported fields are: "year", "month", "day", "hour", "minute", "second", "microsecond", "nanosecond".
|
|
func extractFields(t time.Time, fieldList []string) ([]int64, error) {
|
|
extractedValues := make([]int64, 0, len(fieldList))
|
|
for _, field := range fieldList {
|
|
var val int64
|
|
switch strings.ToLower(field) {
|
|
case common.TszYear:
|
|
val = int64(t.Year())
|
|
case common.TszMonth:
|
|
val = int64(t.Month())
|
|
case common.TszDay:
|
|
val = int64(t.Day())
|
|
case common.TszHour:
|
|
val = int64(t.Hour())
|
|
case common.TszMinute:
|
|
val = int64(t.Minute())
|
|
case common.TszSecond:
|
|
val = int64(t.Second())
|
|
case common.TszMicrosecond:
|
|
val = int64(t.Nanosecond() / 1000)
|
|
default:
|
|
return nil, merr.WrapErrParameterInvalidMsg("unsupported field for extraction: %s, fields should be seprated by ',' or ' '", field)
|
|
}
|
|
extractedValues = append(extractedValues, val)
|
|
}
|
|
return extractedValues, nil
|
|
}
|
|
|
|
func extractFieldsFromResults(results []*schemapb.FieldData, precedenceTimezone []string, fieldList []string) error {
|
|
var targetLocation *time.Location
|
|
|
|
for _, tz := range precedenceTimezone {
|
|
if tz != "" {
|
|
loc, err := time.LoadLocation(tz)
|
|
if err != nil {
|
|
log.Error("invalid timezone provided in precedence list", zap.String("timezone", tz), zap.Error(err))
|
|
return merr.WrapErrParameterInvalidMsg("got invalid timezone: %s", tz)
|
|
}
|
|
targetLocation = loc
|
|
break // Use the first valid timezone found.
|
|
}
|
|
}
|
|
|
|
if targetLocation == nil {
|
|
targetLocation = time.UTC
|
|
}
|
|
|
|
for _, fieldData := range results {
|
|
if fieldData.GetType() != schemapb.DataType_Timestamptz {
|
|
continue
|
|
}
|
|
|
|
scalarField := fieldData.GetScalars()
|
|
if scalarField == nil || scalarField.GetTimestamptzData() == nil {
|
|
if longData := scalarField.GetLongData(); longData != nil && len(longData.GetData()) > 0 {
|
|
log.Warn("field data is not Timestamptz data, but found LongData instead", zap.String("fieldName", fieldData.GetFieldName()))
|
|
return merr.WrapErrParameterInvalidMsg("field data for '%s' is not Timestamptz data", fieldData.GetFieldName())
|
|
}
|
|
continue
|
|
}
|
|
|
|
utcTimestamps := scalarField.GetTimestamptzData().GetData()
|
|
extractedResults := make([]*schemapb.ScalarField, 0, len(fieldList))
|
|
|
|
for _, ts := range utcTimestamps {
|
|
t := time.UnixMicro(ts).UTC()
|
|
localTime := t.In(targetLocation)
|
|
|
|
values, err := extractFields(localTime, fieldList)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
valuesScalarField := &schemapb.ScalarField_LongData{
|
|
LongData: &schemapb.LongArray{
|
|
Data: values,
|
|
},
|
|
}
|
|
extractedResults = append(extractedResults, &schemapb.ScalarField{
|
|
Data: valuesScalarField,
|
|
})
|
|
}
|
|
|
|
fieldData.GetScalars().Data = &schemapb.ScalarField_ArrayData{
|
|
ArrayData: &schemapb.ArrayArray{
|
|
Data: extractedResults,
|
|
ElementType: schemapb.DataType_Int64,
|
|
},
|
|
}
|
|
fieldData.Type = schemapb.DataType_Array
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func genFunctionFields(ctx context.Context, insertMsg *msgstream.InsertMsg, schema *schemaInfo, partialUpdate bool) error {
|
|
allowNonBM25Outputs := common.GetCollectionAllowInsertNonBM25FunctionOutputs(schema.Properties)
|
|
fieldIDs := lo.Map(insertMsg.FieldsData, func(fieldData *schemapb.FieldData, _ int) int64 {
|
|
id, _ := schema.MapFieldID(fieldData.FieldName)
|
|
return id
|
|
})
|
|
|
|
// Since PartialUpdate is supported, the field_data here may not be complete
|
|
needProcessFunctions, err := typeutil.GetNeedProcessFunctions(fieldIDs, schema.Functions, allowNonBM25Outputs, partialUpdate)
|
|
if err != nil {
|
|
log.Ctx(ctx).Warn("Check upsert field error,", zap.String("collectionName", schema.Name), zap.Error(err))
|
|
return err
|
|
}
|
|
|
|
if embedding.HasNonBM25Functions(schema.CollectionSchema.Functions, []int64{}) {
|
|
ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-genFunctionFields-call-function-udf")
|
|
defer sp.End()
|
|
exec, err := embedding.NewFunctionExecutor(schema.CollectionSchema, needProcessFunctions)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
sp.AddEvent("Create-function-udf")
|
|
if err := exec.ProcessInsert(ctx, insertMsg); err != nil {
|
|
return err
|
|
}
|
|
sp.AddEvent("Call-function-udf")
|
|
}
|
|
return nil
|
|
}
|