milvus/internal/proxy/task_query.go
Tianx 2c0c5ef41e
feat: timestamptz expression & index & timezone (#44080)
issue: https://github.com/milvus-io/milvus/issues/27467

>My plan is as follows.
>- [x] M1 Create collection with timestamptz field
>- [x] M2 Insert timestamptz field data
>- [x] M3 Retrieve timestamptz field data
>- [x] M4 Implement handoff
>- [x] M5 Implement compare operator
>- [x] M6 Implement extract operator
>- [x] M8 Support database/collection level default timezone
>- [x] M7 Support STL-SORT index for datatype timestamptz

---

The third PR of issue: https://github.com/milvus-io/milvus/issues/27467,
which completes M5, M6, M7, M8 described above.

## M8 Default Timezone

We will be able to use alter_collection() and alter_database() in a
future Python SDK release to modify the default timezone at the
collection or database level.

For insert requests, the timezone will be resolved using the following
order of precedence: String Literal-> Collection Default -> Database
Default.
For retrieval requests, the timezone will be resolved in this order:
Query Parameters -> Collection Default -> Database Default.
In both cases, the final fallback timezone is UTC.


## M5: Comparison Operators

We can now use the following expression format to filter on the
timestamptz field:

- `timestamptz_field [+/- INTERVAL 'interval_string'] {comparison_op}
ISO 'iso_string' `

- The interval_string follows the ISO 8601 duration format, for example:
P1Y2M3DT1H2M3S.

- The iso_string follows the ISO 8601 timestamp format, for example:
2025-01-03T00:00:00+08:00.

- Example expressions: "tsz + INTERVAL 'P0D' != ISO
'2025-01-03T00:00:00+08:00'" or "tsz != ISO
'2025-01-03T00:00:00+08:00'".

## M6: Extract

We will be able to extract specific time fields by kwargs in a future
Python SDK release.
The key is `time_fields`, and the value should be one or more of "year,
month, day, hour, minute, second, microsecond", separated by commas or
spaces. The result for each record is then an array of int64.



## M7: Indexing Support

Expressions without interval arithmetic can be accelerated using an
STL-SORT index. However, expressions that include interval arithmetic
cannot be indexed. This is because the result of an interval calculation
depends on the specific timestamp value. For example, adding one month
to a date in February results in a different number of added days than
adding one month to a date in March.

--- 

After this PR, timestamptz values are accepted and returned as ISO 8601
strings. Internally, timestamptz is stored as timestamptz data, which is
ultimately an int64_t.

> for more information, see https://en.wikipedia.org/wiki/ISO_8601

---------

Signed-off-by: xtx <xtianx@smail.nju.edu.cn>
2025-09-23 10:24:12 +08:00

934 lines
32 KiB
Go

package proxy
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/cockroachdb/errors"
"github.com/samber/lo"
"go.uber.org/zap"
"google.golang.org/protobuf/proto"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/parser/planparserv2"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/exprutil"
"github.com/milvus-io/milvus/internal/util/reduce"
"github.com/milvus-io/milvus/internal/util/segcore"
typeutil2 "github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/pkg/v2/common"
"github.com/milvus-io/milvus/pkg/v2/log"
"github.com/milvus-io/milvus/pkg/v2/metrics"
"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
"github.com/milvus-io/milvus/pkg/v2/util/merr"
"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)
// Package-level flags and task names used by the query path.
const (
	// WithCache and WithoutCache are readable aliases for a boolean cache
	// toggle; presumably consumed by cache-aware helpers elsewhere in the
	// package — confirm against callers.
	WithCache    = true
	WithoutCache = false

	// RetrieveTaskName is the name reported by queryTask.Name.
	RetrieveTaskName = "RetrieveTask"
	// QueryTaskName is the task-name constant for query tasks.
	QueryTaskName = "QueryTask"
)
// queryTask carries the state of a single proxy query request through
// PreExecute, Execute and PostExecute.
type queryTask struct {
baseTask
Condition
// RetrieveRequest is the request sent to query nodes; queryShard clones it per channel.
*internalpb.RetrieveRequest
// ctx is the task context returned by TraceCtx.
ctx context.Context
// result holds the reduced response built in PostExecute.
result *milvuspb.QueryResults
// request is the original user QueryRequest.
request *milvuspb.QueryRequest
// NOTE(review): mixCoord is not referenced in the visible code; usage elsewhere unconfirmed.
mixCoord types.MixCoordClient
// ids, when non-nil, makes PreExecute replace request.Expr with a primary-key `in` expression.
ids *schemapb.IDs
// collectionName is copied from the request in PreExecute.
collectionName string
// queryParams is the output of parseQueryParams.
queryParams *queryParams
// schema is the cached collection schema loaded in PreExecute.
schema *schemaInfo
// translatedOutputFields, userOutputFields and userDynamicFields come from translateOutputFields.
translatedOutputFields []string
userOutputFields []string
userDynamicFields []string
// resultBuf collects per-shard results inserted by queryShard.
resultBuf *typeutil.ConcurrentSet[*internalpb.RetrieveResults]
// plan is the retrieve plan; createPlanArgs only parses the expression when it is nil.
plan *planpb.PlanNode
// partitionKeyMode is set when the collection uses a partition key.
partitionKeyMode bool
// lb dispatches the query to shards and receives cost metrics.
lb LBPolicy
// channelsMvcc optionally overrides the MVCC timestamp per DML channel.
channelsMvcc map[string]Timestamp
// fastSkip makes queryShard skip channels absent from channelsMvcc.
fastSkip bool
// reQuery marks the task as a sub-task (see IsSubTask). A requery skips
// partition-name conversion and timestamptz formatting, and keeps an
// already assigned timestamp in SetTs.
reQuery bool
// allQueryCnt, totalRelatedDataSize and storageCost are aggregated from shard results in PostExecute.
allQueryCnt int64
totalRelatedDataSize int64
// mustUsePartitionKey makes PreExecute reject requests on non-partition-key collections.
mustUsePartitionKey bool
storageCost segcore.StorageCost
}
// queryParams holds the optional query parameters parsed by parseQueryParams.
type queryParams struct {
// limit is the page size; typeutil.Unlimited when no limit was supplied.
limit int64
// offset is the number of leading merged rows to skip.
offset int64
// reduceType selects how shard results are merged; derived from the iterator
// and reduce-stop-for-best parameters.
reduceType reduce.IReduceType
// isIterator is parsed from the iterator parameter.
isIterator bool
// collectionID, when > 0, must match the resolved collection (checked in PreExecute).
collectionID int64
// timezone is the user-supplied timezone; empty when not provided.
timezone string
// extractTimeFields lists time components to extract from timestamptz fields,
// split on commas and spaces.
extractTimeFields []string
}
// translateToOutputFieldIDs maps output field names to their field IDs.
//
// Each name is looked up first among the schema's regular fields and then
// among the sub-fields of its struct-array fields; an unknown name yields an
// error. The primary key is always part of the result: when outputFields is
// empty it is the only entry, otherwise it is appended if the caller did not
// request it explicitly.
func translateToOutputFieldIDs(outputFields []string, schema *schemapb.CollectionSchema) ([]UniqueID, error) {
	ids := make([]UniqueID, 0, len(outputFields)+1)

	if len(outputFields) == 0 {
		for _, f := range schema.Fields {
			if f.IsPrimaryKey {
				ids = append(ids, f.FieldID)
			}
		}
		return ids, nil
	}

	var pkID UniqueID
	for _, f := range schema.Fields {
		if f.IsPrimaryKey {
			pkID = f.FieldID
		}
	}

	// lookup resolves a name against regular fields first, then struct sub-fields.
	lookup := func(name string) (UniqueID, bool) {
		for _, f := range schema.Fields {
			if f.Name == name {
				return f.FieldID, true
			}
		}
		for _, sf := range schema.StructArrayFields {
			for _, f := range sf.Fields {
				if f.Name == name {
					return f.FieldID, true
				}
			}
		}
		return 0, false
	}

	hasPK := false
	for _, name := range outputFields {
		id, ok := lookup(name)
		if !ok {
			return nil, fmt.Errorf("field %s not exist", name)
		}
		if id == pkID {
			hasPK = true
		}
		ids = append(ids, id)
	}
	// pk field needs to be in output field list
	if !hasPK {
		ids = append(ids, pkID)
	}
	return ids, nil
}
// filterSystemFields returns a new slice containing the IDs from
// outputFieldIDs that are not system fields, in their original order. The
// input slice is not modified.
func filterSystemFields(outputFieldIDs []UniqueID) []UniqueID {
	kept := make([]UniqueID, 0, len(outputFieldIDs))
	for i := range outputFieldIDs {
		if id := outputFieldIDs[i]; !common.IsSystemField(id) {
			kept = append(kept, id)
		}
	}
	return kept
}
// parseQueryParams extracts the optional query parameters from
// queryParamsPair. Every key is optional:
//
//   - ReduceStopForBestKey / IteratorField: booleans that, for iterator
//     requests, select IReduceInOrderForBest or IReduceInOrder; otherwise
//     IReduceNoOrder is used.
//   - CollectionID: integer collection id, parsed with base prefix detection.
//   - TimezoneKey: user timezone, kept verbatim.
//   - TimefieldsKey: time components to extract, split on commas and spaces.
//   - LimitKey / OffsetKey: pagination. Without a limit the result is
//     unlimited and offset is ignored; with a limit, offset and limit are
//     checked by validateMaxQueryResultWindow.
//
// Collection id, timezone and time fields are honoured whether or not a limit
// is supplied.
func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, error) {
	var (
		limit             int64
		offset            int64
		reduceStopForBest bool
		isIterator        bool
		err               error
		collectionID      int64
		timezone          string
		extractTimeFields []string
	)
	reduceStopForBestStr, err := funcutil.GetAttrByKeyFromRepeatedKV(ReduceStopForBestKey, queryParamsPair)
	// if reduce_stop_for_best is provided
	if err == nil {
		reduceStopForBest, err = strconv.ParseBool(reduceStopForBestStr)
		if err != nil {
			return nil, merr.WrapErrParameterInvalid("true or false", reduceStopForBestStr,
				"value for reduce_stop_for_best is invalid")
		}
	}
	isIteratorStr, err := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, queryParamsPair)
	// if iterator is provided
	if err == nil {
		isIterator, err = strconv.ParseBool(isIteratorStr)
		if err != nil {
			return nil, merr.WrapErrParameterInvalid("true or false", isIteratorStr,
				"value for iterator field is invalid")
		}
	}
	collectionIDStr, err := funcutil.GetAttrByKeyFromRepeatedKV(CollectionID, queryParamsPair)
	// if collection id is provided
	if err == nil {
		collectionID, err = strconv.ParseInt(collectionIDStr, 0, 64)
		if err != nil {
			// report the offending value, not the key name
			return nil, merr.WrapErrParameterInvalid("int value for collection_id", collectionIDStr,
				"value for collection id is invalid")
		}
	}
	// Timezone and time fields do not depend on pagination, so they are parsed
	// before the early return for a missing limit below.
	if timezoneStr, tzErr := funcutil.GetAttrByKeyFromRepeatedKV(TimezoneKey, queryParamsPair); tzErr == nil {
		timezone = timezoneStr
	}
	if timeFieldsStr, tfErr := funcutil.GetAttrByKeyFromRepeatedKV(TimefieldsKey, queryParamsPair); tfErr == nil {
		extractTimeFields = strings.FieldsFunc(timeFieldsStr, func(r rune) bool {
			return r == ',' || r == ' '
		})
	}
	reduceType := reduce.IReduceNoOrder
	if isIterator {
		if reduceStopForBest {
			reduceType = reduce.IReduceInOrderForBest
		} else {
			reduceType = reduce.IReduceInOrder
		}
	}
	limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, queryParamsPair)
	// if limit is not provided
	if err != nil {
		return &queryParams{
			limit:             typeutil.Unlimited,
			reduceType:        reduceType,
			isIterator:        isIterator,
			collectionID:      collectionID,
			timezone:          timezone,
			extractTimeFields: extractTimeFields,
		}, nil
	}
	limit, err = strconv.ParseInt(limitStr, 0, 64)
	if err != nil {
		return nil, fmt.Errorf("%s [%s] is invalid", LimitKey, limitStr)
	}
	offsetStr, err := funcutil.GetAttrByKeyFromRepeatedKV(OffsetKey, queryParamsPair)
	// if offset is provided
	if err == nil {
		offset, err = strconv.ParseInt(offsetStr, 0, 64)
		if err != nil {
			return nil, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
		}
	}
	// validate max result window.
	if err = validateMaxQueryResultWindow(offset, limit); err != nil {
		return nil, fmt.Errorf("invalid max query result window, %w", err)
	}
	return &queryParams{
		limit:             limit,
		offset:            offset,
		reduceType:        reduceType,
		isIterator:        isIterator,
		collectionID:      collectionID,
		timezone:          timezone,
		extractTimeFields: extractTimeFields,
	}, nil
}
// matchCountRule reports whether the requested outputs are exactly one
// "count(*)" entry, compared case-insensitively after trimming surrounding
// whitespace.
func matchCountRule(outputs []string) bool {
	if len(outputs) != 1 {
		return false
	}
	normalized := strings.ToLower(strings.TrimSpace(outputs[0]))
	return normalized == "count(*)"
}
// createCntPlan builds a count(*) plan.
//
// An empty expr yields a bare count plan with no predicates. Otherwise expr is
// parsed into a retrieve plan that is then flagged as a count. The parse
// latency is recorded in ProxyParseExpressionLatency with a success or fail
// label, and parse failures are returned as input errors.
func createCntPlan(expr string, schemaHelper *typeutil.SchemaHelper, exprTemplateValues map[string]*schemapb.TemplateValue) (*planpb.PlanNode, error) {
	if len(expr) == 0 {
		countQuery := &planpb.QueryPlanNode{
			Predicates: nil,
			IsCount:    true,
		}
		return &planpb.PlanNode{Node: &planpb.PlanNode_Query{Query: countQuery}}, nil
	}

	startedAt := time.Now()
	plan, err := planparserv2.CreateRetrievePlan(schemaHelper, expr, exprTemplateValues)
	nodeID := strconv.FormatInt(paramtable.GetNodeID(), 10)
	elapsed := float64(time.Since(startedAt).Milliseconds())
	outcome := metrics.SuccessLabel
	if err != nil {
		outcome = metrics.FailLabel
	}
	metrics.ProxyParseExpressionLatency.WithLabelValues(nodeID, "query", outcome).Observe(elapsed)
	if err != nil {
		return nil, merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", err))
	}

	plan.Node.(*planpb.PlanNode_Query).Query.IsCount = true
	return plan, nil
}
// createPlan builds the query plan using default (zero-value) parser visitor
// arguments; see createPlanArgs.
func (t *queryTask) createPlan(ctx context.Context) error {
	defaultArgs := &planparserv2.ParserVisitorArgs{}
	return t.createPlanArgs(ctx, defaultArgs)
}
// createPlanArgs prepares t.plan and the output-field bookkeeping.
//
// For a count(*) request it builds a count plan via createCntPlan and sets the
// user output fields to ["count(*)"]. Otherwise it parses t.request.Expr with
// visitorArgs (only when t.plan is still nil). It then translates the
// requested output fields into field IDs, always including the timestamp
// system field, and records them on both the retrieve request and the plan.
func (t *queryTask) createPlanArgs(ctx context.Context, visitorArgs *planparserv2.ParserVisitorArgs) error {
	schema := t.schema

	if matchCountRule(t.request.GetOutputFields()) {
		countPlan, countErr := createCntPlan(t.request.GetExpr(), schema.schemaHelper, t.request.GetExprTemplateValues())
		t.plan = countPlan
		t.userOutputFields = []string{"count(*)"}
		return countErr
	}

	if t.plan == nil {
		// Only parse the expression when no plan was provided up front.
		parseStart := time.Now()
		plan, parseErr := planparserv2.CreateRetrievePlanArgs(schema.schemaHelper, t.request.Expr, t.request.GetExprTemplateValues(), visitorArgs)
		t.plan = plan
		nodeLabel := strconv.FormatInt(paramtable.GetNodeID(), 10)
		elapsedMs := float64(time.Since(parseStart).Milliseconds())
		if parseErr != nil {
			metrics.ProxyParseExpressionLatency.WithLabelValues(nodeLabel, "query", metrics.FailLabel).Observe(elapsedMs)
			return merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", parseErr))
		}
		metrics.ProxyParseExpressionLatency.WithLabelValues(nodeLabel, "query", metrics.SuccessLabel).Observe(elapsedMs)
	}

	translated, userFields, dynamicFields, _, err := translateOutputFields(t.request.OutputFields, t.schema, false)
	t.translatedOutputFields, t.userOutputFields, t.userDynamicFields = translated, userFields, dynamicFields
	if err != nil {
		return err
	}

	fieldIDs, err := translateToOutputFieldIDs(t.translatedOutputFields, schema.CollectionSchema)
	if err != nil {
		return err
	}
	// The timestamp system field is always fetched alongside the user fields.
	fieldIDs = append(fieldIDs, common.TimeStampField)
	t.RetrieveRequest.OutputFieldsId = fieldIDs
	t.plan.OutputFieldIds = fieldIDs
	t.plan.DynamicFields = t.userDynamicFields

	log.Ctx(ctx).Debug("translate output fields to field ids",
		zap.Int64s("OutputFieldsID", t.OutputFieldsId),
		zap.String("requestType", "query"))
	return nil
}
// CanSkipAllocTimestamp reports whether the task may skip allocating a fresh
// timestamp. Allocation is only needed for Strong consistency, so the result
// is true unless the effective consistency level is Strong.
//
// When the request does not use the default consistency, its own level is
// used. A request asking for Strong with an explicit guarantee timestamp can
// also skip allocation. Otherwise the collection's level is read from the meta
// cache; if either lookup fails, the method conservatively returns false.
func (t *queryTask) CanSkipAllocTimestamp() bool {
	var consistencyLevel commonpb.ConsistencyLevel
	useDefaultConsistency := t.request.GetUseDefaultConsistency()
	if !useDefaultConsistency {
		// legacy SDK & RESTful behavior
		if t.request.GetConsistencyLevel() == commonpb.ConsistencyLevel_Strong && t.request.GetGuaranteeTimestamp() > 0 {
			return true
		}
		consistencyLevel = t.request.GetConsistencyLevel()
	} else {
		collID, err := globalMetaCache.GetCollectionID(context.Background(), t.request.GetDbName(), t.request.GetCollectionName())
		if err != nil { // err is not nil if collection not exists
			log.Ctx(t.ctx).Warn("query task get collectionID failed, can't skip alloc timestamp",
				zap.String("collectionName", t.request.GetCollectionName()), zap.Error(err))
			return false
		}
		collectionInfo, err := globalMetaCache.GetCollectionInfo(context.Background(), t.request.GetDbName(), t.request.GetCollectionName(), collID)
		if err != nil {
			// log the GetCollectionInfo error itself (previously the stale, nil
			// GetCollectionID error was logged here)
			log.Ctx(t.ctx).Warn("query task get collection info failed, can't skip alloc timestamp",
				zap.String("collectionName", t.request.GetCollectionName()), zap.Error(err))
			return false
		}
		consistencyLevel = collectionInfo.consistencyLevel
	}
	return consistencyLevel != commonpb.ConsistencyLevel_Strong
}
// PreExecute validates the query request and fills t.RetrieveRequest before
// dispatch. It resolves the collection ID, info and schema from
// globalMetaCache and enforces the partition-key rules. It parses the query
// params, then builds and serializes the retrieve plan. It resolves partition
// IDs (unless this is a requery). Finally, it computes the guarantee, MVCC,
// collection-TTL and timeout timestamps.
func (t *queryTask) PreExecute(ctx context.Context) error {
t.Base.MsgType = commonpb.MsgType_Retrieve
t.Base.SourceID = paramtable.GetNodeID()
collectionName := t.request.CollectionName
t.collectionName = collectionName
log := log.Ctx(ctx).With(zap.String("collectionName", collectionName),
zap.Strings("partitionNames", t.request.GetPartitionNames()),
zap.String("requestType", "query"))
if err := validateCollectionName(collectionName); err != nil {
log.Warn("Invalid collectionName.")
return err
}
log.Debug("Validate collectionName.")
collID, err := globalMetaCache.GetCollectionID(ctx, t.request.GetDbName(), collectionName)
if err != nil {
log.Warn("Failed to get collection id.", zap.String("collectionName", collectionName), zap.Error(err))
return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
}
t.CollectionID = collID
colInfo, err := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName, t.CollectionID)
if err != nil {
log.Warn("Failed to get collection info.", zap.String("collectionName", collectionName),
zap.Int64("collectionID", t.CollectionID), zap.Error(err))
return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
}
log.Debug("Get collection ID by name", zap.Int64("collectionID", t.CollectionID))
t.partitionKeyMode, err = isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
if err != nil {
log.Warn("check partition key mode failed", zap.Int64("collectionID", t.CollectionID), zap.Error(err))
return err
}
// In partition key mode, partitions are derived from the filter below, so
// explicitly named partitions are rejected.
if t.partitionKeyMode && len(t.request.GetPartitionNames()) != 0 {
return merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("not support manually specifying the partition names if partition key mode is used"))
}
if t.mustUsePartitionKey && !t.partitionKeyMode {
return merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("must use partition key in the query request " +
"because the mustUsePartitionKey config is true"))
}
for _, tag := range t.request.PartitionNames {
if err := validatePartitionTag(tag, false); err != nil {
log.Warn("invalid partition name", zap.String("partition name", tag))
return err
}
}
log.Debug("Validate partition names.")
// fetch search_growing from query param
if t.RetrieveRequest.IgnoreGrowing, err = isIgnoreGrowing(t.request.GetQueryParams()); err != nil {
return err
}
queryParams, err := parseQueryParams(t.request.GetQueryParams())
if err != nil {
return err
}
// A caller-supplied collection id must match the resolved one; per the error
// message, a mismatch means the alias or database changed.
if queryParams.collectionID > 0 && queryParams.collectionID != t.GetCollectionID() {
return merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("Input collection id is not consistent to collectionID in the context," +
"alias or database may have changed"))
}
if queryParams.reduceType == reduce.IReduceInOrderForBest {
t.RetrieveRequest.ReduceStopForBest = true
}
t.RetrieveRequest.ReduceType = int32(queryParams.reduceType)
t.queryParams = queryParams
// Shards are asked for offset+limit rows; the offset is skipped later when
// the shard results are reduced.
t.RetrieveRequest.Limit = queryParams.limit + queryParams.offset
schema, err := globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), t.collectionName)
if err != nil {
log.Warn("get collection schema failed", zap.Error(err))
return err
}
t.schema = schema
// When explicit primary keys are supplied, the filter becomes a pk `in` expression.
if t.ids != nil {
pkField := ""
for _, field := range schema.Fields {
if field.IsPrimaryKey {
pkField = field.Name
}
}
t.request.Expr = IDs2Expr(pkField, t.ids)
}
// Timezone preference passed to the expression parser: the query parameter
// first, then the collection default.
_, colTimezone := getColTimezone(colInfo)
timezonePreference := []string{t.queryParams.timezone, colTimezone}
if err := t.createPlanArgs(ctx, &planparserv2.ParserVisitorArgs{TimezonePreference: timezonePreference}); err != nil {
return err
}
t.plan.Node.(*planpb.PlanNode_Query).Query.Limit = t.RetrieveRequest.Limit
// An always-true filter without a limit is rejected.
if planparserv2.IsAlwaysTruePlan(t.plan) && t.RetrieveRequest.Limit == typeutil.Unlimited {
return merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("empty expression should be used with limit"))
}
// convert partition names only when requery is false
if !t.reQuery {
partitionNames := t.request.GetPartitionNames()
if t.partitionKeyMode {
expr, err := exprutil.ParseExprFromPlan(t.plan)
if err != nil {
return err
}
partitionKeys := exprutil.ParseKeys(expr, exprutil.PartitionKey)
hashedPartitionNames, err := assignPartitionKeys(ctx, t.request.GetDbName(), t.request.CollectionName, partitionKeys)
if err != nil {
return err
}
partitionNames = append(partitionNames, hashedPartitionNames...)
}
t.RetrieveRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.request.CollectionName, partitionNames)
if err != nil {
return err
}
}
// count with pagination
if t.plan.GetQuery().GetIsCount() && t.queryParams.limit != typeutil.Unlimited {
return merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("count entities with pagination is not allowed"))
}
t.RetrieveRequest.IsCount = t.plan.GetQuery().GetIsCount()
t.RetrieveRequest.SerializedExprPlan, err = proto.Marshal(t.plan)
if err != nil {
return err
}
// Set username for this query request,
if username, _ := GetCurUserFromContext(ctx); username != "" {
t.RetrieveRequest.Username = username
}
// NOTE(review): this repeats the GetCollectionInfo lookup already stored in
// colInfo above with the same arguments; the two could likely be unified.
collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName, t.CollectionID)
if err2 != nil {
log.Warn("Proxy::queryTask::PreExecute failed to GetCollectionInfo from cache",
zap.String("collectionName", collectionName), zap.Int64("collectionID", t.CollectionID),
zap.Error(err2))
return err2
}
// Resolve the guarantee timestamp from either the collection's default
// consistency level or the request's own level.
guaranteeTs := t.request.GetGuaranteeTimestamp()
var consistencyLevel commonpb.ConsistencyLevel
useDefaultConsistency := t.request.GetUseDefaultConsistency()
t.RetrieveRequest.ConsistencyLevel = t.request.GetConsistencyLevel()
if useDefaultConsistency {
consistencyLevel = collectionInfo.consistencyLevel
guaranteeTs = parseGuaranteeTsFromConsistency(guaranteeTs, t.BeginTs(), consistencyLevel)
} else {
consistencyLevel = t.request.GetConsistencyLevel()
// Compatibility logic, parse guarantee timestamp
if consistencyLevel == 0 && guaranteeTs > 0 {
guaranteeTs = parseGuaranteeTs(guaranteeTs, t.BeginTs())
} else {
// parse from guarantee timestamp and user input consistency level
guaranteeTs = parseGuaranteeTsFromConsistency(guaranteeTs, t.BeginTs(), consistencyLevel)
}
}
// use collection schema updated timestamp if it's greater than calculate guarantee timestamp
// this make query view updated happens before new read request happens
// see also schema change design
if collectionInfo.updateTimestamp > guaranteeTs {
guaranteeTs = collectionInfo.updateTimestamp
}
t.GuaranteeTimestamp = guaranteeTs
// need modify mvccTs and guaranteeTs for iterator specially
if t.queryParams.isIterator && t.request.GetGuaranteeTimestamp() > 0 {
t.MvccTimestamp = t.request.GetGuaranteeTimestamp()
t.GuaranteeTimestamp = t.request.GetGuaranteeTimestamp()
}
t.RetrieveRequest.IsIterator = queryParams.isIterator
// Compose the TTL cutoff timestamp: the base timestamp's physical time minus
// the collection TTL.
if collectionInfo.collectionTTL != 0 {
physicalTime := tsoutil.PhysicalTime(t.GetBase().GetTimestamp())
expireTime := physicalTime.Add(-time.Duration(collectionInfo.collectionTTL))
t.CollectionTtlTimestamps = tsoutil.ComposeTSByTime(expireTime, 0)
// preventing overflow, abort
if t.CollectionTtlTimestamps > t.GetBase().GetTimestamp() {
return merr.WrapErrServiceInternal(fmt.Sprintf("ttl timestamp overflow, base timestamp: %d, ttl duration %v", t.GetBase().GetTimestamp(), collectionInfo.collectionTTL))
}
}
// Propagate the request deadline, if any, as a timeout timestamp.
deadline, ok := t.TraceCtx().Deadline()
if ok {
t.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
}
t.DbID = 0 // TODO
log.Debug("Query PreExecute done.",
zap.Uint64("guarantee_ts", guaranteeTs),
zap.Uint64("mvcc_ts", t.GetMvccTimestamp()),
zap.Uint64("timeout_ts", t.GetTimeoutTimestamp()),
zap.Uint64("collection_ttl_timestamps", t.CollectionTtlTimestamps))
return nil
}
// Execute dispatches the query through the load balancer (t.lb), which calls
// queryShard for the channels it targets. Each successful shard result is
// stored in a fresh resultBuf for PostExecute to reduce. A dispatch failure
// is logged and returned wrapped.
func (t *queryTask) Execute(ctx context.Context) error {
	recorder := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute query %d", t.ID()))
	defer recorder.CtxElapse(ctx, "done")

	logger := log.Ctx(ctx).With(
		zap.Int64("collection", t.GetCollectionID()),
		zap.Int64s("partitionIDs", t.GetPartitionIDs()),
		zap.String("requestType", "query"),
	)

	t.resultBuf = typeutil.NewConcurrentSet[*internalpb.RetrieveResults]()
	workload := CollectionWorkLoad{
		db:             t.request.GetDbName(),
		collectionID:   t.CollectionID,
		collectionName: t.collectionName,
		nq:             1,
		exec:           t.queryShard,
	}
	if err := t.lb.Execute(ctx, workload); err != nil {
		logger.Warn("fail to execute query", zap.Error(err))
		return errors.Wrap(err, "failed to query")
	}

	logger.Debug("Query Execute done.")
	return nil
}
// reconstructStructFieldData regroups flattened struct-array sub-fields in
// results back into their parent ArrayOfStruct fields.
//
// Query results carry every struct-array sub-field as its own FieldData. This
// keeps the other fields in their original order, then appends one
// ArrayOfStruct FieldData per struct-array field that has sub-fields present.
// Those are appended in schema order, so the output is deterministic.
// OutputFields is rewritten to match FieldsData. Count(*) results, nil inputs
// and schemas without struct-array fields are left untouched. A FieldData
// whose ID is neither a regular field nor a known struct sub-field is passed
// through unchanged instead of being grouped under a nameless struct.
func reconstructStructFieldData(results *milvuspb.QueryResults, schema *schemapb.CollectionSchema) {
	if results == nil || schema == nil {
		return
	}
	if len(results.OutputFields) == 1 && results.OutputFields[0] == "count(*)" {
		return
	}
	if len(schema.StructArrayFields) == 0 {
		return
	}

	// record all regular field IDs
	regularFieldIDs := make(map[int64]struct{}, len(schema.Fields))
	for _, field := range schema.Fields {
		regularFieldIDs[field.GetFieldID()] = struct{}{}
	}
	// build the mapping from sub-field ID to struct field ID
	subFieldToStruct := make(map[int64]int64)
	for _, structField := range schema.StructArrayFields {
		for _, subField := range structField.GetFields() {
			subFieldToStruct[subField.GetFieldID()] = structField.GetFieldID()
		}
	}

	fieldsData := make([]*schemapb.FieldData, 0, len(results.FieldsData))
	outputFields := make([]string, 0, len(results.FieldsData))
	grouped := make(map[int64][]*schemapb.FieldData)
	for _, field := range results.FieldsData {
		fieldID := field.GetFieldId()
		if _, isRegular := regularFieldIDs[fieldID]; !isRegular {
			if structID, isSub := subFieldToStruct[fieldID]; isSub {
				grouped[structID] = append(grouped[structID], field)
				continue
			}
		}
		// regular fields, and fields unknown to the schema, pass through as-is
		fieldsData = append(fieldsData, field)
		outputFields = append(outputFields, field.GetFieldName())
	}

	// Walk struct fields in schema order rather than map order so the output
	// is deterministic across calls.
	for _, structField := range schema.StructArrayFields {
		structID := structField.GetFieldID()
		subFields, ok := grouped[structID]
		if !ok {
			continue
		}
		delete(grouped, structID) // emit each struct at most once
		fieldsData = append(fieldsData, &schemapb.FieldData{
			FieldName: structField.GetName(),
			FieldId:   structID,
			Type:      schemapb.DataType_ArrayOfStruct,
			Field:     &schemapb.FieldData_StructArrays{StructArrays: &schemapb.StructArrayField{Fields: subFields}},
		})
		outputFields = append(outputFields, structField.GetName())
	}

	results.FieldsData = fieldsData
	results.OutputFields = outputFields
}
// PostExecute merges the per-shard results gathered in resultBuf and reduces
// them into t.result. It then sets the output field names and primary-key
// field name. On the first page of an iterator it records the session
// timestamp. For non-requery tasks it either extracts the requested time
// fields or renders timestamptz values as ISO strings, using the query
// timezone and then the collection default timezone.
func (t *queryTask) PostExecute(ctx context.Context) error {
tr := timerecord.NewTimeRecorder("queryTask PostExecute")
defer func() {
tr.CtxElapse(ctx, "done")
}()
log := log.Ctx(ctx).With(zap.Int64("collection", t.GetCollectionID()),
zap.Int64s("partitionIDs", t.GetPartitionIDs()),
zap.String("requestType", "query"))
var err error
toReduceResults := make([]*internalpb.RetrieveResults, 0)
t.allQueryCnt = 0
t.totalRelatedDataSize = 0
t.storageCost = segcore.StorageCost{}
select {
case <-t.TraceCtx().Done():
// NOTE(review): on timeout this returns nil without setting t.result;
// callers must tolerate a nil result.
log.Warn("proxy", zap.Int64("Query: wait to finish failed, timeout!, msgID:", t.ID()))
return nil
default:
log.Debug("all queries are finished or canceled")
// Collect shard results and accumulate counters and storage-cost statistics.
t.resultBuf.Range(func(res *internalpb.RetrieveResults) bool {
toReduceResults = append(toReduceResults, res)
t.allQueryCnt += res.GetAllRetrieveCount()
t.storageCost.ScannedRemoteBytes += res.GetScannedRemoteBytes()
t.storageCost.ScannedTotalBytes += res.GetScannedTotalBytes()
t.totalRelatedDataSize += res.GetCostAggregation().GetTotalRelatedDataSize()
log.Debug("proxy receives one query result", zap.Int64("sourceID", res.GetBase().GetSourceID()))
return true
})
}
metrics.ProxyDecodeResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Observe(0.0)
tr.CtxRecord(ctx, "reduceResultStart")
reducer := createMilvusReducer(ctx, t.queryParams, t.RetrieveRequest, t.schema.CollectionSchema, t.plan, t.collectionName)
t.result, err = reducer.Reduce(toReduceResults)
if err != nil {
log.Warn("fail to reduce query result", zap.Error(err))
return err
}
t.result.OutputFields = t.userOutputFields
if !t.reQuery {
reconstructStructFieldDataForQuery(t.result, t.schema.CollectionSchema)
}
primaryFieldSchema, err := t.schema.GetPkField()
if err != nil {
log.Warn("failed to get primary field schema", zap.Error(err))
return err
}
t.result.PrimaryFieldName = primaryFieldSchema.GetName()
metrics.ProxyReduceResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Observe(float64(tr.RecordSpan().Milliseconds()))
if t.queryParams.isIterator && t.request.GetGuaranteeTimestamp() == 0 {
// first page for iteration, need to set up sessionTs for iterator
t.result.SessionTs = getMaxMvccTsFromChannels(t.channelsMvcc, t.BeginTs())
}
// Translate timestamp to ISO string
// NOTE(review): this lookup uses context.Background() rather than ctx and
// repeats the collection ID resolution already stored in t.CollectionID by
// PreExecute. It also runs (and can fail) for requery tasks, which do not use
// the timezone.
collName := t.request.GetCollectionName()
dbName := t.request.GetDbName()
collID, err := globalMetaCache.GetCollectionID(context.Background(), dbName, collName)
if err != nil {
log.Warn("fail to get collection id", zap.Error(err))
return err
}
colInfo, err := globalMetaCache.GetCollectionInfo(ctx, dbName, collName, collID)
if err != nil {
log.Warn("fail to get collection info", zap.Error(err))
return err
}
_, colTimezone := getColTimezone(colInfo)
if !t.reQuery {
if len(t.queryParams.extractTimeFields) > 0 {
log.Debug("extracting fields for timestamptz", zap.Strings("fields", t.queryParams.extractTimeFields))
err = extractFieldsFromResults(t.result.GetFieldsData(), []string{t.queryParams.timezone, colTimezone}, t.queryParams.extractTimeFields)
if err != nil {
log.Warn("fail to extract fields for timestamptz", zap.Error(err))
return err
}
} else {
log.Debug("translate timestamp to ISO string", zap.String("user define timezone", t.queryParams.timezone))
err = timestamptzUTC2IsoStr(t.result.GetFieldsData(), t.queryParams.timezone, colTimezone)
if err != nil {
log.Warn("fail to translate timestamp", zap.Error(err))
return err
}
}
}
log.Debug("Query PostExecute done")
return nil
}
// IsSubTask reports whether this task was created as a requery (t.reQuery)
// rather than as a standalone query.
func (t *queryTask) IsSubTask() bool {
	isRequery := t.reQuery
	return isRequery
}
// queryShard sends the retrieve request to one query node for one DML channel
// and stores a successful result in t.resultBuf.
//
// When channelsMvcc is populated, the channel's entry (if non-zero) overrides
// the MVCC and guarantee timestamps of the cloned request. With fastSkip set,
// channels missing from channelsMvcc are skipped without sending a request. On
// an RPC error or a NotShardLeader reply, the collection's shard cache is
// deprecated before the error is returned.
func (t *queryTask) queryShard(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channel string) error {
	overrideMvcc := false
	mvccTs := t.MvccTimestamp
	if len(t.channelsMvcc) > 0 {
		channelTs, found := t.channelsMvcc[channel]
		if !found && t.fastSkip {
			// fast mode: nothing recorded for this channel, skip it
			return nil
		}
		mvccTs, overrideMvcc = channelTs, found
	}

	retrieveReq := typeutil.Clone(t.RetrieveRequest)
	retrieveReq.GetBase().TargetID = nodeID
	if overrideMvcc && mvccTs > 0 {
		retrieveReq.MvccTimestamp = mvccTs
		retrieveReq.GuaranteeTimestamp = mvccTs
	}
	retrieveReq.ConsistencyLevel = t.ConsistencyLevel

	logger := log.Ctx(ctx).With(zap.Int64("collection", t.GetCollectionID()),
		zap.Int64s("partitionIDs", t.GetPartitionIDs()),
		zap.Int64("nodeID", nodeID),
		zap.String("channel", channel))

	result, err := qn.Query(ctx, &querypb.QueryRequest{
		Req:         retrieveReq,
		DmlChannels: []string{channel},
		Scope:       querypb.DataScope_All,
	})
	if err != nil {
		logger.Warn("QueryNode query return error", zap.Error(err))
		globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
		return err
	}

	status := result.GetStatus()
	switch status.GetErrorCode() {
	case commonpb.ErrorCode_Success:
		// fall out of the switch and record the result
	case commonpb.ErrorCode_NotShardLeader:
		logger.Warn("QueryNode is not shardLeader")
		globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
		return errInvalidShardLeaders
	default:
		logger.Warn("QueryNode query result error", zap.Any("errorCode", status.GetErrorCode()), zap.String("reason", status.GetReason()))
		return errors.Wrapf(merr.Error(status), "fail to Query on QueryNode %d", nodeID)
	}

	logger.Debug("get query result")
	t.resultBuf.Insert(result)
	t.lb.UpdateCostMetrics(nodeID, result.CostAggregation)
	return nil
}
// IDs2Expr converts primary keys into a boolean `in` expression on fieldName,
// e.g. `pk in [ 1, 2, 3 ]` or `pk in [ "a", "b" ]`.
//
// Integer keys are written in base 10. String keys are written as
// double-quoted literals escaped with strconv.Quote. Keys containing quotes,
// backslashes or control characters therefore cannot terminate the literal
// early or inject extra expression syntax; plain keys render exactly as
// before. Empty or unset ids yield `fieldName in [  ]`.
func IDs2Expr(fieldName string, ids *schemapb.IDs) string {
	var literals []string
	switch ids.GetIdField().(type) {
	case *schemapb.IDs_IntId:
		data := ids.GetIntId().GetData()
		literals = make([]string, len(data))
		for i, id := range data {
			literals[i] = strconv.FormatInt(id, 10)
		}
	case *schemapb.IDs_StrId:
		data := ids.GetStrId().GetData()
		literals = make([]string, len(data))
		for i, s := range data {
			literals[i] = strconv.Quote(s)
		}
	}
	return fieldName + " in [ " + strings.Join(literals, ", ") + " ]"
}
// reduceRetrieveResults merges per-shard retrieve results into a single
// QueryResults. It repeatedly picks the next row via typeutil.SelectMinPK
// (presumably the smallest primary key across shards — confirm in typeutil).
//
// Results that are nil, or carry no fields or no IDs, are ignored. queryParams
// may be nil, which is treated as no limit, no offset and an unordered
// (IReduceNoOrder) reduce.
//
// When a limit is set and reduce.ShouldUseInputLimit accepts the reduce type,
// at most limit rows are emitted; otherwise every available row is merged.
// The first offset rows are skipped. For reduce types where
// reduce.ShouldStopWhenDrained is true, merging stops once any shard is
// drained. An error is returned if the accumulated output exceeds
// QuotaConfig.MaxOutputSize.
func reduceRetrieveResults(ctx context.Context, retrieveResults []*internalpb.RetrieveResults, queryParams *queryParams) (*milvuspb.QueryResults, error) {
	log.Ctx(ctx).Debug("reduceInternalRetrieveResults", zap.Int("len(retrieveResults)", len(retrieveResults)))
	var (
		ret     = &milvuspb.QueryResults{}
		loopEnd int
	)

	// Resolve the reduce type once so that a nil queryParams cannot be
	// dereferenced in the drain checks below.
	reduceType := reduce.IReduceNoOrder
	if queryParams != nil {
		reduceType = queryParams.reduceType
	}

	validRetrieveResults := make([]*internalpb.RetrieveResults, 0, len(retrieveResults))
	for _, r := range retrieveResults {
		if r == nil {
			continue
		}
		size := typeutil.GetSizeOfIDs(r.GetIds())
		if len(r.GetFieldsData()) == 0 || size == 0 {
			continue
		}
		validRetrieveResults = append(validRetrieveResults, r)
		loopEnd += size
	}
	if len(validRetrieveResults) == 0 {
		return ret, nil
	}

	cursors := make([]int64, len(validRetrieveResults))
	// IReduceInOrderForBest will try to get as many results as possible, so
	// for it loopEnd stays at the sum of all results' sizes; reduce types that
	// honour the input limit cap loopEnd at limit.
	if queryParams != nil && queryParams.limit != typeutil.Unlimited && reduce.ShouldUseInputLimit(reduceType) {
		loopEnd = int(queryParams.limit)
	}

	// handle offset: skip the first offset rows in merge order
	if queryParams != nil && queryParams.offset > 0 {
		for i := int64(0); i < queryParams.offset; i++ {
			sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors)
			if sel == -1 || (reduce.ShouldStopWhenDrained(reduceType) && drainOneResult) {
				return ret, nil
			}
			cursors[sel]++
		}
	}

	ret.FieldsData = typeutil.PrepareResultFieldData(validRetrieveResults[0].GetFieldsData(), int64(loopEnd))
	var retSize int64
	maxOutputSize := paramtable.Get().QuotaConfig.MaxOutputSize.GetAsInt64()
	for j := 0; j < loopEnd; j++ {
		sel, drainOneResult := typeutil.SelectMinPK(validRetrieveResults, cursors)
		if sel == -1 || (reduce.ShouldStopWhenDrained(reduceType) && drainOneResult) {
			break
		}
		retSize += typeutil.AppendFieldData(ret.FieldsData, validRetrieveResults[sel].GetFieldsData(), cursors[sel])
		// limit retrieve result to avoid oom
		if retSize > maxOutputSize {
			return nil, fmt.Errorf("query results exceed the maxOutputSize Limit %d", maxOutputSize)
		}
		cursors[sel]++
	}
	return ret, nil
}
// reduceRetrieveResultsAndFillIfEmpty reduces retrieveResults with
// reduceRetrieveResults. It then passes the non-system IDs from outputFieldsID
// to typeutil2.FillRetrieveResultIfEmpty, which — judging by its name —
// populates the result's field columns when no rows were returned.
// Reduce errors are returned as-is. Fill errors are wrapped with %w so callers
// can still inspect the cause with errors.Is/As.
func reduceRetrieveResultsAndFillIfEmpty(ctx context.Context, retrieveResults []*internalpb.RetrieveResults, queryParams *queryParams, outputFieldsID []int64, schema *schemapb.CollectionSchema) (*milvuspb.QueryResults, error) {
	result, err := reduceRetrieveResults(ctx, retrieveResults, queryParams)
	if err != nil {
		return nil, err
	}
	// filter system fields.
	filtered := filterSystemFields(outputFieldsID)
	if err = typeutil2.FillRetrieveResultIfEmpty(typeutil2.NewMilvusResult(result), filtered, schema); err != nil {
		return nil, fmt.Errorf("failed to fill retrieve results: %w", err)
	}
	return result, nil
}
// TraceCtx returns the task context stored in t.ctx.
func (t *queryTask) TraceCtx() context.Context { return t.ctx }

// ID returns the message ID held in the request base.
func (t *queryTask) ID() UniqueID { return t.Base.MsgID }

// SetID stores uid as the message ID in the request base.
func (t *queryTask) SetID(uid UniqueID) { t.Base.MsgID = uid }

// Name returns RetrieveTaskName.
func (t *queryTask) Name() string { return RetrieveTaskName }

// Type returns the message type recorded in the request base.
func (t *queryTask) Type() commonpb.MsgType { return t.Base.MsgType }

// BeginTs returns the task timestamp; a query has a single timestamp, so
// BeginTs and EndTs report the same value.
func (t *queryTask) BeginTs() Timestamp { return t.Base.Timestamp }

// EndTs returns the task timestamp (identical to BeginTs).
func (t *queryTask) EndTs() Timestamp { return t.Base.Timestamp }

// SetTs assigns the task timestamp, except that a requery which already has a
// non-zero timestamp keeps it.
func (t *queryTask) SetTs(ts Timestamp) {
	if !t.reQuery || t.Base.Timestamp == 0 {
		t.Base.Timestamp = ts
	}
}

// OnEnqueue makes sure the request base exists and stamps it as a Retrieve
// message originating from this node.
func (t *queryTask) OnEnqueue() error {
	if t.Base == nil {
		t.Base = commonpbutil.NewMsgBase()
	}
	base := t.Base
	base.MsgType = commonpb.MsgType_Retrieve
	base.SourceID = paramtable.GetNodeID()
	return nil
}