issue: https://github.com/milvus-io/milvus/issues/27467

> My plan is as follows.
> - [x] M1 Create collection with timestamptz field
> - [x] M2 Insert timestamptz field data
> - [x] M3 Retrieve timestamptz field data
> - [x] M4 Implement handoff
> - [x] M5 Implement compare operator
> - [x] M6 Implement extract operator
> - [x] M8 Support database/collection level default timezone
> - [x] M7 Support STL-SORT index for datatype timestamptz

---

This is the third PR for issue https://github.com/milvus-io/milvus/issues/27467; it completes M5, M6, M7, and M8 described above.

## M8: Default Timezone

A future Python SDK release will allow `alter_collection()` and `alter_database()` to modify the default timezone at the collection or database level.

For insert requests, the timezone is resolved in the following order of precedence: string literal -> collection default -> database default. For retrieval requests, the order is: query parameters -> collection default -> database default. In both cases, the final fallback timezone is UTC.

## M5: Comparison Operators

The following expression format can now be used to filter on a timestamptz field:

- `timestamptz_field [+/- INTERVAL 'interval_string'] {comparison_op} ISO 'iso_string'`
- The interval_string follows the ISO 8601 duration format, for example: P1Y2M3DT1H2M3S.
- The iso_string follows the ISO 8601 timestamp format, for example: 2025-01-03T00:00:00+08:00.
- Example expressions: "tsz + INTERVAL 'P0D' != ISO '2025-01-03T00:00:00+08:00'" or "tsz != ISO '2025-01-03T00:00:00+08:00'".

## M6: Extract

A future Python SDK release will allow extracting specific time fields via kwargs. The key is `time_fields`, and the value should be one or more of "year, month, day, hour, minute, second, microsecond", separated by commas or spaces. The result for each record is then an array of int64.

## M7: Indexing Support

Expressions without interval arithmetic can be accelerated using an STL-SORT index. Expressions that include interval arithmetic cannot be indexed, because the result of an interval calculation depends on the specific timestamp value: for example, adding one month to a date in February adds a different number of days than adding one month to a date in March (the standalone sketch after this description illustrates the point).

---

After this PR, the input/output type of timestamptz is an ISO string. Timestamptz values are stored as timestamptz data, which is ultimately an int64_t.

> For more information, see https://en.wikipedia.org/wiki/ISO_8601

---------

Signed-off-by: xtx <xtianx@smail.nju.edu.cn>
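To make the M7 point concrete, here is a minimal, self-contained Go sketch (standard library only, not part of this PR or the Milvus codebase) showing why the result of an interval expression such as `tsz + INTERVAL 'P1M'` cannot be read off the stored int64 order alone: the same one-month interval shifts different timestamps by different amounts.

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	// Two timestamptz values given as ISO 8601 / RFC 3339 strings.
	feb, _ := time.Parse(time.RFC3339, "2025-02-28T00:00:00+08:00")
	mar, _ := time.Parse(time.RFC3339, "2025-03-28T00:00:00+08:00")

	// Apply the calendar part of an ISO 8601 duration "P1M" (one month).
	febPlus := feb.AddDate(0, 1, 0) // 2025-03-28T00:00:00+08:00
	marPlus := mar.AddDate(0, 1, 0) // 2025-04-28T00:00:00+08:00

	// The same interval adds 28 days in one case and 31 days in the other,
	// so a sorted index over the raw int64 values cannot answer
	// "tsz + INTERVAL 'P1M' < ISO '...'" without re-evaluating every row.
	fmt.Println(febPlus.Sub(feb)) // 672h0m0s (28 days)
	fmt.Println(marPlus.Sub(mar)) // 744h0m0s (31 days)
}
```

Comparisons without the interval term (for example `tsz != ISO '2025-01-03T00:00:00+08:00'`) reduce to plain int64 comparisons, which is what makes the STL-SORT index applicable there.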
703 lines · 20 KiB · Go
package proxy

import (
	"context"
	"fmt"
	"regexp"
	"strconv"
	"strings"

	"github.com/cockroachdb/errors"
	"github.com/google/uuid"
	"google.golang.org/protobuf/proto"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
	"github.com/milvus-io/milvus/pkg/v2/common"
	"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
	"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
	"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
	"github.com/milvus-io/milvus/pkg/v2/util/merr"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

type rankParams struct {
	limit           int64
	offset          int64
	roundDecimal    int64
	groupByFieldId  int64
	groupSize       int64
	strictGroupSize bool
}

func (r *rankParams) GetLimit() int64 {
	if r != nil {
		return r.limit
	}
	return 0
}

func (r *rankParams) GetOffset() int64 {
	if r != nil {
		return r.offset
	}
	return 0
}

func (r *rankParams) GetRoundDecimal() int64 {
	if r != nil {
		return r.roundDecimal
	}
	return 0
}

func (r *rankParams) GetGroupByFieldId() int64 {
	if r != nil {
		return r.groupByFieldId
	}
	return -1
}

func (r *rankParams) GetGroupSize() int64 {
	if r != nil {
		return r.groupSize
	}
	return 1
}

func (r *rankParams) GetStrictGroupSize() bool {
	if r != nil {
		return r.strictGroupSize
	}
	return false
}

func (r *rankParams) String() string {
	return fmt.Sprintf("limit: %d, offset: %d, roundDecimal: %d", r.GetLimit(), r.GetOffset(), r.GetRoundDecimal())
}

type SearchInfo struct {
	planInfo     *planpb.QueryInfo
	offset       int64
	isIterator   bool
	collectionID int64
}

func parseSearchIteratorV2Info(searchParamsPair []*commonpb.KeyValuePair, groupByFieldId int64, isIterator bool, offset int64, queryTopK *int64) (*planpb.SearchIteratorV2Info, error) {
	isIteratorV2Str, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterV2Key, searchParamsPair)
	isIteratorV2, _ := strconv.ParseBool(isIteratorV2Str)
	if !isIteratorV2 {
		return nil, nil
	}

	// iteratorV1 and iteratorV2 should be set together for compatibility
	if !isIterator {
		return nil, fmt.Errorf("both %s and %s must be set in the SDK", IteratorField, SearchIterV2Key)
	}

	// disable groupBy when doing iteratorV2
	// same behavior with V1
	if isIteratorV2 && groupByFieldId > 0 {
		return nil, merr.WrapErrParameterInvalid("", "",
			"GroupBy is not permitted when using a search iterator")
	}

	// disable offset when doing iteratorV2
	if isIteratorV2 && offset > 0 {
		return nil, merr.WrapErrParameterInvalid("", "",
			"Setting an offset is not permitted when using a search iterator v2")
	}

	// parse token, generate if not exist
	token, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterIdKey, searchParamsPair)
	if token == "" {
		generatedToken, err := uuid.NewRandom()
		if err != nil {
			return nil, err
		}
		token = generatedToken.String()
	} else {
		// Validate existing token is a valid UUID
		if _, err := uuid.Parse(token); err != nil {
			return nil, errors.New("invalid token format")
		}
	}

	// parse batch size, required non-zero value
	batchSizeStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterBatchSizeKey, searchParamsPair)
	if batchSizeStr == "" {
		return nil, errors.New("batch size is required")
	}
	batchSize, err := strconv.ParseInt(batchSizeStr, 0, 64)
	if err != nil {
		return nil, fmt.Errorf("batch size is invalid, %w", err)
	}
	// use the same validation logic as topk
	if err := validateLimit(batchSize); err != nil {
		return nil, fmt.Errorf("batch size is invalid, %w", err)
	}
	*queryTopK = batchSize // for compatibility

	// prepare plan iterator v2 info proto
	planIteratorV2Info := &planpb.SearchIteratorV2Info{
		Token:     token,
		BatchSize: uint32(batchSize),
	}

	// append optional last bound if applicable
	lastBoundStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(SearchIterLastBoundKey, searchParamsPair)
	if lastBoundStr != "" {
		lastBound, err := strconv.ParseFloat(lastBoundStr, 32)
		if err != nil {
			return nil, fmt.Errorf("failed to parse input last bound, %w", err)
		}
		lastBoundFloat32 := float32(lastBound)
		planIteratorV2Info.LastBound = &lastBoundFloat32 // escape pointer
	}

	return planIteratorV2Info, nil
}

// parseSearchInfo returns QueryInfo and offset
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) (*SearchInfo, error) {
	var topK int64
	isAdvanced := rankParams != nil
	externalLimit := rankParams.GetLimit() + rankParams.GetOffset()
	topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
	if err != nil {
		if externalLimit <= 0 {
			return nil, fmt.Errorf("%s is required", TopKKey)
		}
		topK = externalLimit
	} else {
		topKInParam, err := strconv.ParseInt(topKStr, 0, 64)
		if err != nil {
			if externalLimit <= 0 {
				return nil, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
			}
			topK = externalLimit
		} else {
			topK = topKInParam
		}
	}

	isIteratorStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)
	isIterator := (isIteratorStr == "True") || (isIteratorStr == "true")

	collectionIDStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(CollectionID, searchParamsPair)
	collectionId, _ := strconv.ParseInt(collectionIDStr, 0, 64)

	if err := validateLimit(topK); err != nil {
		if isIterator {
			// 1. if the request is from iterator, we set topK to QuotaLimit as the iterator can resolve too large topK problem
			// 2. GetAsInt64 has cached inside, no need to worry about cpu cost for parsing here
			topK = Params.QuotaConfig.TopKLimit.GetAsInt64()
		} else {
			return nil, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)
		}
	}

	var offset int64
	// ignore offset if isAdvanced
	if !isAdvanced {
		offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, searchParamsPair)
		if err == nil {
			offset, err = strconv.ParseInt(offsetStr, 0, 64)
			if err != nil {
				return nil, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
			}

			if offset != 0 {
				if err := validateLimit(offset); err != nil {
					return nil, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)
				}
			}
		}
	}

	queryTopK := topK + offset
	if err := validateLimit(queryTopK); err != nil {
		return nil, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)
	}

	// 2. parse metrics type
	metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(common.MetricTypeKey, searchParamsPair)
	if err != nil {
		metricType = ""
	}

	// 3. parse round decimal
	roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, searchParamsPair)
	if err != nil {
		roundDecimalStr = "-1"
	}

	hints, err := funcutil.GetAttrByKeyFromRepeatedKV(common.HintsKey, searchParamsPair)
	if err != nil {
		hints = ""
	}

	roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64)
	if err != nil {
		return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
	}

	if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
		return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
	}

	// 4. parse search param str
	searchParamStr, err := funcutil.GetAttrByKeyFromRepeatedKV(ParamsKey, searchParamsPair)
	if err != nil {
		searchParamStr = ""
	}

	// 5. parse group by field and group by size
	var groupByFieldId, groupSize int64
	var strictGroupSize bool
	var jsonPath string
	var jsonType schemapb.DataType
	var strictCast bool
	if isAdvanced {
		groupByFieldId, groupSize, strictGroupSize = rankParams.GetGroupByFieldId(), rankParams.GetGroupSize(), rankParams.GetStrictGroupSize()
	} else {
		groupByInfo, err := parseGroupByInfo(searchParamsPair, schema)
		if err != nil {
			return nil, err
		}
		groupByFieldId, groupSize, strictGroupSize = groupByInfo.GetGroupByFieldId(), groupByInfo.GetGroupSize(), groupByInfo.GetStrictGroupSize()
		jsonPath, jsonType, strictCast = groupByInfo.GetJSONPath(), groupByInfo.GetJSONType(), groupByInfo.GetStrictCast()
		if jsonPath != "" {
			jsonPath, err = typeutil2.ParseAndVerifyNestedPath(jsonPath, schema, groupByFieldId)
			if err != nil {
				return nil, err
			}
		}
	}

	// 6. parse iterator tag, prevent trying to groupBy when doing iteration or doing range-search
	if isIterator && groupByFieldId > 0 {
		return nil, merr.WrapErrParameterInvalid("", "",
			"Not allowed to do groupBy when doing iteration")
	}
	if strings.Contains(searchParamStr, radiusKey) && groupByFieldId > 0 {
		return nil, merr.WrapErrParameterInvalid("", "",
			"Not allowed to do range-search when doing search-group-by")
	}

	planSearchIteratorV2Info, err := parseSearchIteratorV2Info(searchParamsPair, groupByFieldId, isIterator, offset, &queryTopK)
	if err != nil {
		return nil, fmt.Errorf("parse iterator v2 info failed: %w", err)
	}

	// 7. check search for embedding list
	annsFieldName, _ := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, searchParamsPair)
	if annsFieldName != "" {
		annField := typeutil.GetFieldByName(schema, annsFieldName)
		if annField != nil && annField.GetDataType() == schemapb.DataType_ArrayOfVector {
			if strings.Contains(searchParamStr, radiusKey) {
				return nil, merr.WrapErrParameterInvalid("", "",
					"range search is not supported for vector array (embedding list) fields, fieldName:", annsFieldName)
			}

			if groupByFieldId > 0 {
				return nil, merr.WrapErrParameterInvalid("", "",
					"group by search is not supported for vector array (embedding list) fields, fieldName:", annsFieldName)
			}

			if isIterator {
				return nil, merr.WrapErrParameterInvalid("", "",
					"search iterator is not supported for vector array (embedding list) fields, fieldName:", annsFieldName)
			}
		}
	}

	return &SearchInfo{
		planInfo: &planpb.QueryInfo{
			Topk:                 queryTopK,
			MetricType:           metricType,
			SearchParams:         searchParamStr,
			RoundDecimal:         roundDecimal,
			GroupByFieldId:       groupByFieldId,
			GroupSize:            groupSize,
			StrictGroupSize:      strictGroupSize,
			Hints:                hints,
			SearchIteratorV2Info: planSearchIteratorV2Info,
			JsonPath:             jsonPath,
			JsonType:             jsonType,
			StrictCast:           strictCast,
		},
		offset:       offset,
		isIterator:   isIterator,
		collectionID: collectionId,
	}, nil
}

func getOutputFieldIDs(schema *schemaInfo, outputFields []string) (outputFieldIDs []UniqueID, err error) {
	outputFieldIDs = make([]UniqueID, 0, len(outputFields))
	for _, name := range outputFields {
		id, ok := schema.MapFieldID(name)
		if !ok {
			return nil, fmt.Errorf("Field %s not exist", name)
		}
		outputFieldIDs = append(outputFieldIDs, id)
	}
	return outputFieldIDs, nil
}

func getNqFromSubSearch(req *milvuspb.SubSearchRequest) (int64, error) {
	if req.GetNq() == 0 {
		// keep compatible with older client version.
		x := &commonpb.PlaceholderGroup{}
		err := proto.Unmarshal(req.GetPlaceholderGroup(), x)
		if err != nil {
			return 0, err
		}
		total := int64(0)
		for _, h := range x.GetPlaceholders() {
			total += int64(len(h.Values))
		}
		return total, nil
	}
	return req.GetNq(), nil
}

func getNq(req *milvuspb.SearchRequest) (int64, error) {
	if req.GetNq() == 0 {
		// keep compatible with older client version.
		x := &commonpb.PlaceholderGroup{}
		err := proto.Unmarshal(req.GetPlaceholderGroup(), x)
		if err != nil {
			return 0, err
		}
		total := int64(0)
		for _, h := range x.GetPlaceholders() {
			total += int64(len(h.Values))
		}
		return total, nil
	}
	return req.GetNq(), nil
}

func getPartitionIDs(ctx context.Context, dbName string, collectionName string, partitionNames []string) (partitionIDs []UniqueID, err error) {
	for _, tag := range partitionNames {
		if err := validatePartitionTag(tag, false); err != nil {
			return nil, err
		}
	}

	partitionsMap, err := globalMetaCache.GetPartitions(ctx, dbName, collectionName)
	if err != nil {
		return nil, err
	}

	useRegexp := Params.ProxyCfg.PartitionNameRegexp.GetAsBool()

	partitionsSet := typeutil.NewUniqueSet()
	for _, partitionName := range partitionNames {
		if useRegexp {
			// Legacy feature, use partition name as regexp
			pattern := fmt.Sprintf("^%s$", partitionName)
			re, err := regexp.Compile(pattern)
			if err != nil {
				return nil, fmt.Errorf("invalid partition: %s", partitionName)
			}
			var found bool
			for name, pID := range partitionsMap {
				if re.MatchString(name) {
					partitionsSet.Insert(pID)
					found = true
				}
			}
			if !found {
				return nil, fmt.Errorf("partition name %s not found", partitionName)
			}
		} else {
			partitionID, found := partitionsMap[partitionName]
			if !found {
				// TODO change after testcase updated: return nil, merr.WrapErrPartitionNotFound(partitionName)
				return nil, fmt.Errorf("partition name %s not found", partitionName)
			}
			partitionsSet.Insert(partitionID)
		}
	}
	return partitionsSet.Collect(), nil
}

type groupByInfo struct {
	groupByFieldId  int64
	groupSize       int64
	strictGroupSize bool
	jsonPath        string
	jsonType        schemapb.DataType
	strictCast      bool
}

func (g *groupByInfo) GetGroupByFieldId() int64 {
	if g != nil {
		return g.groupByFieldId
	}
	return 0
}

func (g *groupByInfo) GetGroupSize() int64 {
	if g != nil {
		return g.groupSize
	}
	return 0
}

func (g *groupByInfo) GetStrictGroupSize() bool {
	if g != nil {
		return g.strictGroupSize
	}
	return false
}

func (g *groupByInfo) GetJSONPath() string {
	if g != nil {
		return g.jsonPath
	}
	return ""
}

func (g *groupByInfo) GetJSONType() schemapb.DataType {
	if g != nil {
		return g.jsonType
	}
	return schemapb.DataType_None
}

func (g *groupByInfo) GetStrictCast() bool {
	if g != nil {
		return g.strictCast
	}
	return false
}

func parseGroupByInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) (*groupByInfo, error) {
	ret := &groupByInfo{}

	// 1. parse group_by_field
	groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
	if err != nil {
		groupByFieldName = ""
	}
	var groupByFieldId int64 = -1
	if groupByFieldName != "" {
		fields := schema.GetFields()
		var dynamicField *schemapb.FieldSchema
		for _, field := range fields {
			if field.Name == groupByFieldName {
				groupByFieldId = field.FieldID
				break
			}
			if field.GetIsDynamic() {
				dynamicField = field
				break
			}
		}
		if groupByFieldId == -1 {
			if dynamicField != nil {
				groupByFieldId = dynamicField.FieldID
			} else {
				return nil, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
			}
		}
	}
	ret.groupByFieldId = groupByFieldId

	// 2. parse group size
	var groupSize int64
	groupSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupSizeKey, searchParamsPair)
	if err != nil {
		groupSize = 1
	} else {
		groupSize, err = strconv.ParseInt(groupSizeStr, 0, 64)
		if err != nil {
			return nil, merr.WrapErrParameterInvalidMsg(
				fmt.Sprintf("failed to parse input group size:%s", groupSizeStr))
		}
		if groupSize <= 0 {
			return nil, merr.WrapErrParameterInvalidMsg(
				fmt.Sprintf("input group size:%d is negative, failed to do search_groupby", groupSize))
		}
	}
	if groupSize > Params.QuotaConfig.MaxGroupSize.GetAsInt64() {
		return nil, merr.WrapErrParameterInvalidMsg(
			fmt.Sprintf("input group size:%d exceeds configured max group size:%d", groupSize, Params.QuotaConfig.MaxGroupSize.GetAsInt64()))
	}
	ret.groupSize = groupSize

	// 3. parse group strict size
	var strictGroupSize bool
	strictGroupSizeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(StrictGroupSize, searchParamsPair)
	if err != nil {
		strictGroupSize = false
	} else {
		strictGroupSize, err = strconv.ParseBool(strictGroupSizeStr)
		if err != nil {
			strictGroupSize = false
		}
	}
	ret.strictGroupSize = strictGroupSize

	// 4. parse json path
	jsonPath, err := funcutil.GetAttrByKeyFromRepeatedKV(JSONPath, searchParamsPair)
	if err == nil {
		ret.jsonPath = jsonPath
	}

	// 5. parse json type
	jsonTypeStr, err := funcutil.GetAttrByKeyFromRepeatedKV(JSONType, searchParamsPair)
	if err == nil {
		dataTypeVal, ok := schemapb.DataType_value[jsonTypeStr]
		if ok {
			ret.jsonType = schemapb.DataType(dataTypeVal)
		}
	}

	// 6. parse strict cast
	strictCastStr, err := funcutil.GetAttrByKeyFromRepeatedKV(StrictCastKey, searchParamsPair)
	if err == nil {
		strictCast, err := strconv.ParseBool(strictCastStr)
		if err != nil {
			strictCast = false
		}
		ret.strictCast = strictCast
	}

	return ret, nil
}

// parseRankParams get limit and offset from rankParams, both are optional.
func parseRankParams(rankParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema) (*rankParams, error) {
	var (
		limit        int64
		offset       int64
		roundDecimal int64
		err          error
	)

	limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, rankParamsPair)
	if err != nil {
		return nil, errors.New(LimitKey + " not found in rank_params")
	}
	limit, err = strconv.ParseInt(limitStr, 0, 64)
	if err != nil {
		return nil, fmt.Errorf("%s [%s] is invalid", LimitKey, limitStr)
	}

	offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, rankParamsPair)
	if err == nil {
		offset, err = strconv.ParseInt(offsetStr, 0, 64)
		if err != nil {
			return nil, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
		}
	}

	// validate max result window.
	if err = validateMaxQueryResultWindow(offset, limit); err != nil {
		return nil, fmt.Errorf("invalid max query result window, %w", err)
	}

	roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, rankParamsPair)
	if err != nil {
		roundDecimalStr = "-1"
	}

	roundDecimal, err = strconv.ParseInt(roundDecimalStr, 0, 64)
	if err != nil {
		return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
	}

	if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
		return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
	}

	// parse group_by parameters from main request body for hybrid search
	groupByInfo, err := parseGroupByInfo(rankParamsPair, schema)
	if err != nil {
		return nil, err
	}

	return &rankParams{
		limit:           limit,
		offset:          offset,
		roundDecimal:    roundDecimal,
		groupByFieldId:  groupByInfo.GetGroupByFieldId(),
		groupSize:       groupByInfo.GetGroupSize(),
		strictGroupSize: groupByInfo.GetStrictGroupSize(),
	}, nil
}

func parseTimezone(params []*commonpb.KeyValuePair) string {
	timezone, err := funcutil.GetAttrByKeyFromRepeatedKV(TimezoneKey, params)
	if err != nil {
		return ""
	}
	return timezone
}

func parseTimeFields(params []*commonpb.KeyValuePair) []string {
	timeFields, err := funcutil.GetAttrByKeyFromRepeatedKV(TimefieldsKey, params)
	if err != nil {
		return nil
	}
	return strings.FieldsFunc(timeFields, func(r rune) bool {
		return r == ',' || r == ' '
	})
}

func getGroupScorerStr(params []*commonpb.KeyValuePair) string {
	groupScorerStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RankGroupScorer, params)
	if err != nil {
		groupScorerStr = MaxScorer
	}
	return groupScorerStr
}

func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {
	ret := &milvuspb.SearchRequest{
		Base:                  req.GetBase(),
		DbName:                req.GetDbName(),
		CollectionName:        req.GetCollectionName(),
		PartitionNames:        req.GetPartitionNames(),
		OutputFields:          req.GetOutputFields(),
		SearchParams:          req.GetRankParams(),
		TravelTimestamp:       req.GetTravelTimestamp(),
		GuaranteeTimestamp:    req.GetGuaranteeTimestamp(),
		Nq:                    0,
		NotReturnAllMeta:      req.GetNotReturnAllMeta(),
		ConsistencyLevel:      req.GetConsistencyLevel(),
		UseDefaultConsistency: req.GetUseDefaultConsistency(),
		SearchByPrimaryKeys:   false,
		SubReqs:               nil,
		FunctionScore:         req.FunctionScore,
	}

	for _, sub := range req.GetRequests() {
		subReq := &milvuspb.SubSearchRequest{
			Dsl:                sub.GetDsl(),
			PlaceholderGroup:   sub.GetPlaceholderGroup(),
			DslType:            sub.GetDslType(),
			SearchParams:       sub.GetSearchParams(),
			Nq:                 sub.GetNq(),
			ExprTemplateValues: sub.GetExprTemplateValues(),
		}
		ret.SubReqs = append(ret.SubReqs, subReq)
	}
	return ret
}

func getMetricType(toReduceResults []*internalpb.SearchResults) string {
	for _, r := range toReduceResults {
		if m := r.GetMetricType(); m != "" {
			return m
		}
	}
	return ""
}