package proxy

import (
	"context"
	"fmt"
	"math"
	"strconv"
	"strings"
	"time"

	"github.com/cockroachdb/errors"
	"github.com/samber/lo"
	"go.opentelemetry.io/otel"
	"go.uber.org/zap"
	"google.golang.org/protobuf/proto"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/json"
	"github.com/milvus-io/milvus/internal/parser/planparserv2"
	"github.com/milvus-io/milvus/internal/proxy/accesslog"
	"github.com/milvus-io/milvus/internal/proxy/shardclient"
	"github.com/milvus-io/milvus/internal/types"
	"github.com/milvus-io/milvus/internal/util/exprutil"
	"github.com/milvus-io/milvus/internal/util/function/embedding"
	"github.com/milvus-io/milvus/internal/util/function/models"
	"github.com/milvus-io/milvus/internal/util/function/rerank"
	"github.com/milvus-io/milvus/internal/util/segcore"
	"github.com/milvus-io/milvus/pkg/v2/common"
	"github.com/milvus-io/milvus/pkg/v2/log"
	"github.com/milvus-io/milvus/pkg/v2/metrics"
	"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
	"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
	"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
	"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
	"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
	"github.com/milvus-io/milvus/pkg/v2/util/merr"
	"github.com/milvus-io/milvus/pkg/v2/util/metric"
	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
	"github.com/milvus-io/milvus/pkg/v2/util/timerecord"
	"github.com/milvus-io/milvus/pkg/v2/util/timestamptz"
	"github.com/milvus-io/milvus/pkg/v2/util/tsoutil"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

const (
	SearchTaskName = "SearchTask"
	SearchLevelKey = "level"

	// requeryThreshold is the estimated threshold for the size of the search results.
	// If the estimated size of the search results exceeds this threshold,
	// a second query request is issued to retrieve the output fields' data.
	// In that case, the first search does not return any output field from the QueryNodes.
	requeryThreshold = 0.5 * 1024 * 1024

	radiusKey      = "radius"
	rangeFilterKey = "range_filter"
)

// type requery func(span trace.Span, ids *schemapb.IDs, outputFields []string) (*milvuspb.QueryResults, error)

type searchTask struct {
	baseTask
	Condition
	ctx context.Context
	*internalpb.SearchRequest

	result  *milvuspb.SearchResults
	request *milvuspb.SearchRequest

	tr                     *timerecord.TimeRecorder
	collectionName         string
	schema                 *schemaInfo
	needRequery            bool
	partitionKeyMode       bool
	enableMaterializedView bool
	mustUsePartitionKey    bool
	resultSizeInsufficient bool
	isTopkReduce           bool
	isRecallEvaluation     bool

	translatedOutputFields []string
	userOutputFields       []string
	userDynamicFields      []string

	highlightTasks []*highlightTask

	resultBuf       *typeutil.ConcurrentSet[*internalpb.SearchResults]
	partitionIDsSet *typeutil.ConcurrentSet[UniqueID]

	mixCoord        types.MixCoordClient
	node            types.ProxyComponent
	lb              shardclient.LBPolicy
	shardClientMgr  shardclient.ShardClientMgr
	queryChannelsTs map[string]Timestamp
	queryInfos      []*planpb.QueryInfo
	relatedDataSize int64

	// New reranker functions
	functionScore *rerank.FunctionScore
	rankParams    *rankParams

	resolvedTimezoneStr string
	isIterator          bool

	// We always remove the pk field from the output fields, as the search result already contains it.
	// If the user explicitly set the pk field in the output fields, we add it back to the result.
	userRequestedPkFieldExplicitly bool

	storageCost segcore.StorageCost
}
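// CanSkipAllocTimestamp reports whether this search can skip allocating a fresh
// timestamp from the TSO. Allocation is only required for Strong consistency;
// for legacy SDK/RESTful requests that already carry a guarantee timestamp
// together with Strong consistency, allocation is skipped as well.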
func (t *searchTask) CanSkipAllocTimestamp() bool {
	var consistencyLevel commonpb.ConsistencyLevel
	useDefaultConsistency := t.request.GetUseDefaultConsistency()
	if !useDefaultConsistency {
		// legacy SDK & RESTful behavior
		if t.request.GetConsistencyLevel() == commonpb.ConsistencyLevel_Strong && t.request.GetGuaranteeTimestamp() > 0 {
			return true
		}
		consistencyLevel = t.request.GetConsistencyLevel()
	} else {
		collID, err := globalMetaCache.GetCollectionID(context.Background(), t.request.GetDbName(), t.request.GetCollectionName())
		if err != nil { // err is not nil if the collection does not exist
			log.Ctx(t.ctx).Warn("search task get collectionID failed, can't skip alloc timestamp",
				zap.String("collectionName", t.request.GetCollectionName()), zap.Error(err))
			return false
		}

		collectionInfo, err2 := globalMetaCache.GetCollectionInfo(context.Background(), t.request.GetDbName(), t.request.GetCollectionName(), collID)
		if err2 != nil {
			log.Ctx(t.ctx).Warn("search task get collection info failed, can't skip alloc timestamp",
				zap.String("collectionName", t.request.GetCollectionName()), zap.Error(err2))
			return false
		}
		consistencyLevel = collectionInfo.consistencyLevel
	}
	return consistencyLevel != commonpb.ConsistencyLevel_Strong
}

func (t *searchTask) PreExecute(ctx context.Context) error {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-PreExecute")
	defer sp.End()

	t.SearchRequest.IsAdvanced = len(t.request.GetSubReqs()) > 0
	t.Base.MsgType = commonpb.MsgType_Search
	t.Base.SourceID = paramtable.GetNodeID()

	collectionName := t.request.CollectionName
	t.collectionName = collectionName

	collID, err := globalMetaCache.GetCollectionID(ctx, t.request.GetDbName(), collectionName)
	if err != nil { // err is not nil if the collection does not exist
		return merr.WrapErrAsInputErrorWhen(err, merr.ErrCollectionNotFound, merr.ErrDatabaseNotFound)
	}
	t.SearchRequest.DbID = 0 // todo
	t.SearchRequest.CollectionID = collID
	log := log.Ctx(ctx).With(zap.Int64("collID", collID), zap.String("collName", collectionName))

	t.schema, err = globalMetaCache.GetCollectionSchema(ctx, t.request.GetDbName(), collectionName)
	if err != nil {
		log.Warn("get collection schema failed", zap.Error(err))
		return err
	}

	t.partitionKeyMode, err = isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
	if err != nil {
		log.Warn("is partition key mode failed", zap.Error(err))
		return err
	}
	if t.partitionKeyMode && len(t.request.GetPartitionNames()) != 0 {
		return errors.New("manually specifying the partition names is not supported when partition key mode is used")
	}
	if t.mustUsePartitionKey && !t.partitionKeyMode {
		return merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("must use partition key in the search request " +
			"because the mustUsePartitionKey config is true"))
	}
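	// Partition names are resolved to partition IDs here only when partition key
	// mode is off; with a partition key, the IDs are derived later from the
	// partition-key expression in the plan (see tryParsePartitionIDsFromPlan).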
	if !t.partitionKeyMode && len(t.request.GetPartitionNames()) > 0 {
		// translate partition names to partition IDs; a regex pattern is used to match partition names
		t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, t.request.GetPartitionNames())
		if err != nil {
			log.Warn("failed to get partition ids", zap.Error(err))
			return err
		}
	}

	err = common.CheckNamespace(t.schema.CollectionSchema, t.request.Namespace)
	if err != nil {
		return err
	}

	t.translatedOutputFields, t.userOutputFields, t.userDynamicFields, t.userRequestedPkFieldExplicitly, err = translateOutputFields(t.request.OutputFields, t.schema, true)
	if err != nil {
		log.Warn("translate output fields failed", zap.Error(err), zap.Any("schema", t.schema))
		return err
	}
	log.Debug("translate output fields", zap.Strings("output fields", t.translatedOutputFields))

	if t.SearchRequest.GetIsAdvanced() {
		if len(t.request.GetSubReqs()) > defaultMaxSearchRequest {
			return fmt.Errorf("maximum number of ann search requests is %d", defaultMaxSearchRequest)
		}
	}

	nq, err := t.checkNq(ctx)
	if err != nil {
		log.Info("failed to check nq", zap.Error(err))
		return err
	}
	t.SearchRequest.Nq = nq

	if t.SearchRequest.IgnoreGrowing, err = isIgnoreGrowing(t.request.SearchParams); err != nil {
		return err
	}

	outputFieldIDs, err := getOutputFieldIDs(t.schema, t.translatedOutputFields)
	if err != nil {
		log.Info("fail to get output field ids", zap.Error(err))
		return err
	}
	t.SearchRequest.OutputFieldsId = outputFieldIDs

	// Currently, we get vectors by requery. Once we support getting vectors from search,
	// searches with a small result size would no longer need requery.
	if t.SearchRequest.GetIsAdvanced() {
		err = t.initAdvancedSearchRequest(ctx)
	} else {
		err = t.initSearchRequest(ctx)
	}
	if err != nil {
		log.Debug("init search request failed", zap.Error(err))
		return err
	}

	collectionInfo, err2 := globalMetaCache.GetCollectionInfo(ctx, t.request.GetDbName(), collectionName, t.CollectionID)
	if err2 != nil {
		log.Warn("Proxy::searchTask::PreExecute failed to GetCollectionInfo from cache",
			zap.String("collectionName", collectionName), zap.Int64("collectionID", t.CollectionID), zap.Error(err2))
		return err2
	}

	guaranteeTs := t.request.GetGuaranteeTimestamp()
	var consistencyLevel commonpb.ConsistencyLevel
	useDefaultConsistency := t.request.GetUseDefaultConsistency()
	if useDefaultConsistency {
		consistencyLevel = collectionInfo.consistencyLevel
		guaranteeTs = parseGuaranteeTsFromConsistency(guaranteeTs, t.BeginTs(), consistencyLevel)
	} else {
		consistencyLevel = t.request.GetConsistencyLevel()
		// Compatibility logic: parse the guarantee timestamp directly when no
		// consistency level was provided.
		if consistencyLevel == 0 && guaranteeTs > 0 {
			guaranteeTs = parseGuaranteeTs(guaranteeTs, t.BeginTs())
		} else {
			// parse from the guarantee timestamp and the user-provided consistency level
			guaranteeTs = parseGuaranteeTsFromConsistency(guaranteeTs, t.BeginTs(), consistencyLevel)
		}
	}

	// update the actual consistency level
	accesslog.SetActualConsistencyLevel(ctx, consistencyLevel)

	// Use the collection schema's update timestamp if it is greater than the computed
	// guarantee timestamp. This makes the query-view update happen before the new
	// read request; see also the schema change design.
	if collectionInfo.updateTimestamp > guaranteeTs {
		guaranteeTs = collectionInfo.updateTimestamp
	}

	t.SearchRequest.GuaranteeTimestamp = guaranteeTs
	t.SearchRequest.ConsistencyLevel = consistencyLevel

	if t.isIterator && t.request.GetGuaranteeTimestamp() > 0 {
		t.MvccTimestamp = t.request.GetGuaranteeTimestamp()
		t.GuaranteeTimestamp = t.request.GetGuaranteeTimestamp()
	}
	t.SearchRequest.IsIterator = t.isIterator
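	// A TSO timestamp packs a physical part (milliseconds) and a logical counter
	// into a single uint64, roughly:
	//
	//	ts = uint64(physical.UnixMilli())<<logicalBits | logical
	//
	// so ComposeTSByTime(deadline, 0) below yields the timestamp at the deadline
	// with a zero logical part (illustrative sketch only; see pkg/v2/util/tsoutil
	// for the actual layout).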
	if deadline, ok := t.TraceCtx().Deadline(); ok {
		t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
	}

	// Set the username of this search request for features like task scheduling.
	if username, _ := GetCurUserFromContext(ctx); username != "" {
		t.SearchRequest.Username = username
	}

	if collectionInfo.collectionTTL != 0 {
		physicalTime := tsoutil.PhysicalTime(t.GetBase().GetTimestamp())
		expireTime := physicalTime.Add(-time.Duration(collectionInfo.collectionTTL))
		t.CollectionTtlTimestamps = tsoutil.ComposeTSByTime(expireTime, 0)
		// prevent overflow: abort if the TTL timestamp exceeds the base timestamp
		if t.CollectionTtlTimestamps > t.GetBase().GetTimestamp() {
			return merr.WrapErrServiceInternal(fmt.Sprintf("ttl timestamp overflow, base timestamp: %d, ttl duration %v", t.GetBase().GetTimestamp(), collectionInfo.collectionTTL))
		}
	}

	timezone, exist := funcutil.TryGetAttrByKeyFromRepeatedKV(common.TimezoneKey, t.request.SearchParams)
	if exist {
		if !timestamptz.IsTimezoneValid(timezone) {
			log.Info("got invalid timezone from request", zap.String("timezone", timezone))
			return merr.WrapErrParameterInvalidMsg("unknown or invalid IANA Time Zone ID: %s", timezone)
		}
		log.Debug("determined timezone from request", zap.String("user defined timezone", timezone))
	} else {
		timezone = getColTimezone(collectionInfo)
		log.Debug("determined timezone from collection", zap.Any("collection timezone", timezone))
	}
	t.resolvedTimezoneStr = timezone

	t.resultBuf = typeutil.NewConcurrentSet[*internalpb.SearchResults]()

	if err = ValidateTask(t); err != nil {
		return err
	}

	log.Debug("search PreExecute done.",
		zap.Uint64("guarantee_ts", guaranteeTs),
		zap.Bool("use_default_consistency", useDefaultConsistency),
		zap.Any("consistency level", consistencyLevel),
		zap.Uint64("timeout_ts", t.SearchRequest.GetTimeoutTimestamp()),
		zap.Uint64("collection_ttl_timestamps", t.CollectionTtlTimestamps))
	return nil
}
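// checkNq determines and validates the number of query vectors (nq).
// For an advanced (hybrid) search, every sub-request must carry the same nq
// (e.g. sub-requests with nq=1 and nq=2 are rejected); for a plain search, nq
// is taken from the request itself. The final value is checked against the
// server-side NQ limit.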
func (t *searchTask) checkNq(ctx context.Context) (int64, error) {
	var nq int64
	if t.SearchRequest.GetIsAdvanced() {
		// In the context of Advanced Search, it is essential to verify that the number
		// of vectors for each individual search, denoted as nq, remains consistent.
		nq = t.request.GetNq()
		for _, req := range t.request.GetSubReqs() {
			subNq, err := getNqFromSubSearch(req)
			if err != nil {
				return 0, err
			}
			req.Nq = subNq
			if nq == 0 {
				nq = subNq
				continue
			}
			if subNq != nq {
				err = merr.WrapErrParameterInvalid(nq, subNq, "sub search request nq should be the same")
				return 0, err
			}
		}
		t.request.Nq = nq
	} else {
		var err error
		nq, err = getNq(t.request)
		if err != nil {
			return 0, err
		}
		t.request.Nq = nq
	}

	// Check if nq is valid:
	// https://milvus.io/docs/limitations.md
	if err := validateNQLimit(nq); err != nil {
		return 0, fmt.Errorf("%s [%d] is invalid, %w", NQKey, nq, err)
	}
	return nq, nil
}

func setQueryInfoIfMvEnable(queryInfo *planpb.QueryInfo, t *searchTask, plan *planpb.PlanNode) error {
	if t.enableMaterializedView {
		partitionKeyFieldSchema, err := typeutil.GetPartitionKeyFieldSchema(t.schema.CollectionSchema)
		if err != nil {
			log.Ctx(t.ctx).Warn("failed to get partition key field schema", zap.Error(err))
			return err
		}
		if typeutil.IsFieldDataTypeSupportMaterializedView(partitionKeyFieldSchema) {
			collInfo, colErr := globalMetaCache.GetCollectionInfo(t.ctx, t.request.GetDbName(), t.collectionName, t.CollectionID)
			if colErr != nil {
				log.Ctx(t.ctx).Warn("failed to get collection info", zap.Error(colErr))
				return colErr
			}
			if collInfo.partitionKeyIsolation {
				expr, err := exprutil.ParseExprFromPlan(plan)
				if err != nil {
					log.Ctx(t.ctx).Warn("failed to parse expr from plan during MV", zap.Error(err))
					return err
				}
				err = exprutil.ValidatePartitionKeyIsolation(expr)
				if err != nil {
					return err
				}
				// force hints to "disable"
				queryInfo.Hints = "disable"
			}
			queryInfo.MaterializedViewInvolved = true
		} else {
			return errors.New("partition key field data type is not supported in materialized view")
		}
	}
	return nil
}

func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init advanced search request")
	defer sp.End()

	t.partitionIDsSet = typeutil.NewConcurrentSet[UniqueID]()
	log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))

	var err error
	// TODO: use function score uniformly to implement the related logic
	if t.request.FunctionScore != nil {
		if t.functionScore, err = rerank.NewFunctionScore(t.schema.CollectionSchema, t.request.FunctionScore, &models.ModelExtraInfo{ClusterID: paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), DBName: t.request.GetDbName()}); err != nil {
			log.Warn("Failed to create function score", zap.Error(err))
			return err
		}
	} else {
		if t.functionScore, err = rerank.NewFunctionScoreWithlegacy(t.schema.CollectionSchema, t.request.GetSearchParams()); err != nil {
			log.Warn("Failed to create function score by legacy info", zap.Error(err))
			return err
		}
	}

	allFields := typeutil.GetAllFieldSchemas(t.schema.CollectionSchema)
	vectorOutputFields := lo.Filter(allFields, func(field *schemapb.FieldSchema, _ int) bool {
		return lo.Contains(t.translatedOutputFields, field.GetName()) && typeutil.IsVectorType(field.GetDataType())
	})

	if t.rankParams, err = parseRankParams(t.request.GetSearchParams(), t.schema.CollectionSchema); err != nil {
		log.Error("parseRankParams failed", zap.Error(err))
		return err
	}

	switch strings.ToLower(paramtable.Get().CommonCfg.HybridSearchRequeryPolicy.GetValue()) {
	case "always":
		t.needRequery = true
	case "outputvector":
		// hybrid search with group-by cannot skip requery, because the binding
		// between pk and group-by field is not guaranteed otherwise
		t.needRequery = len(vectorOutputFields) > 0 || t.rankParams.GetGroupByFieldId() >= 0
	case "outputfields":
		fallthrough
	default:
		t.needRequery = len(t.request.GetOutputFields()) > 0
	}
	t.needRequery = t.needRequery || len(t.functionScore.GetAllInputFieldNames()) > 0

	if !t.functionScore.IsSupportGroup() && t.rankParams.GetGroupByFieldId() >= 0 {
		return merr.WrapErrParameterInvalidMsg("Current rerank does not support grouping search")
	}
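	// Build one internal sub-request per user sub-request: each gets its own
	// plan, query info, partition-ID resolution, and serialized expression plan.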
	t.SearchRequest.SubReqs = make([]*internalpb.SubSearchRequest, len(t.request.GetSubReqs()))
	t.queryInfos = make([]*planpb.QueryInfo, len(t.request.GetSubReqs()))
	queryFieldIDs := []int64{}
	for index, subReq := range t.request.GetSubReqs() {
		plan, queryInfo, offset, _, err := t.tryGeneratePlan(subReq.GetSearchParams(), subReq.GetDsl(), subReq.GetExprTemplateValues())
		if err != nil {
			return err
		}

		ignoreGrowing := t.SearchRequest.IgnoreGrowing
		if !ignoreGrowing {
			// fetch ignore_growing from the sub-search params if it was not set in the search request
			if ignoreGrowing, err = isIgnoreGrowing(subReq.GetSearchParams()); err != nil {
				return err
			}
		}

		internalSubReq := &internalpb.SubSearchRequest{
			Dsl:                subReq.GetDsl(),
			PlaceholderGroup:   subReq.GetPlaceholderGroup(),
			DslType:            subReq.GetDslType(),
			SerializedExprPlan: nil,
			Nq:                 subReq.GetNq(),
			PartitionIDs:       nil,
			Topk:               queryInfo.GetTopk(),
			Offset:             offset,
			MetricType:         queryInfo.GetMetricType(),
			GroupByFieldId:     t.rankParams.GetGroupByFieldId(),
			GroupSize:          t.rankParams.GetGroupSize(),
			IgnoreGrowing:      ignoreGrowing,
		}

		// set the analyzer name for the sub-search
		analyzer, err := funcutil.GetAttrByKeyFromRepeatedKV(AnalyzerKey, subReq.GetSearchParams())
		if err == nil {
			internalSubReq.AnalyzerName = analyzer
		}

		internalSubReq.FieldId = queryInfo.GetQueryFieldId()
		if err := t.addHighlightTask(t.request.GetHighlighter(), internalSubReq.GetMetricType(), internalSubReq.FieldId, subReq.GetPlaceholderGroup(), internalSubReq.GetAnalyzerName()); err != nil {
			return err
		}
		queryFieldIDs = append(queryFieldIDs, internalSubReq.FieldId)

		// set PartitionIDs for the sub-search
		if t.partitionKeyMode {
			// isolation has a tighter constraint, so check it first
			mvErr := setQueryInfoIfMvEnable(queryInfo, t, plan)
			if mvErr != nil {
				return mvErr
			}
			partitionIDs, err2 := t.tryParsePartitionIDsFromPlan(plan)
			if err2 != nil {
				return err2
			}
			if len(partitionIDs) > 0 {
				internalSubReq.PartitionIDs = partitionIDs
				t.partitionIDsSet.Upsert(partitionIDs...)
			}
		} else {
			internalSubReq.PartitionIDs = t.SearchRequest.GetPartitionIDs()
		}

		if t.needRequery {
			plan.OutputFieldIds = t.functionScore.GetAllInputFieldIDs()
		} else {
			primaryFieldSchema, err := t.schema.GetPkField()
			if err != nil {
				return err
			}
			allFieldIDs := typeutil.NewSet(t.SearchRequest.OutputFieldsId...)
			allFieldIDs.Insert(t.functionScore.GetAllInputFieldIDs()...)
			allFieldIDs.Insert(primaryFieldSchema.FieldID)
			plan.OutputFieldIds = allFieldIDs.Collect()
			plan.DynamicFields = t.userDynamicFields
		}
		plan.Namespace = t.request.Namespace

		internalSubReq.SerializedExprPlan, err = proto.Marshal(plan)
		if err != nil {
			return err
		}

		if typeutil.IsFieldSparseFloatVector(t.schema.CollectionSchema, internalSubReq.FieldId) {
			metrics.ProxySearchSparseNumNonZeros.WithLabelValues(
				strconv.FormatInt(paramtable.GetNodeID(), 10),
				t.collectionName,
				metrics.HybridSearchLabel,
				strconv.FormatInt(internalSubReq.FieldId, 10),
			).Observe(float64(typeutil.EstimateSparseVectorNNZFromPlaceholderGroup(internalSubReq.PlaceholderGroup, int(internalSubReq.GetNq()))))
		}
		t.SearchRequest.SubReqs[index] = internalSubReq
		t.queryInfos[index] = queryInfo
		log.Debug("proxy init search request",
			zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
			zap.Stringer("plan", plan)) // may be very large if a large term is passed
	}
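	// If any query field is backed by a non-BM25 function (e.g. a text-embedding
	// function), run the function executor so the placeholders are converted
	// before the request is dispatched.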
	if embedding.HasNonBM25Functions(t.schema.CollectionSchema.Functions, queryFieldIDs) {
		ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-AdvancedSearch-call-function-udf")
		defer sp.End()
		exec, err := embedding.NewFunctionExecutor(t.schema.CollectionSchema, nil, &models.ModelExtraInfo{ClusterID: paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), DBName: t.request.GetDbName()})
		if err != nil {
			return err
		}
		sp.AddEvent("Create-function-udf")
		if err := exec.ProcessSearch(ctx, t.SearchRequest); err != nil {
			return err
		}
		sp.AddEvent("Call-function-udf")
	}

	t.SearchRequest.GroupByFieldId = t.rankParams.GetGroupByFieldId()
	t.SearchRequest.GroupSize = t.rankParams.GetGroupSize()
	if t.partitionKeyMode {
		t.SearchRequest.PartitionIDs = t.partitionIDsSet.Collect()
	}
	return nil
}

func (t *searchTask) fillResult() {
	limit := t.SearchRequest.GetTopk() - t.SearchRequest.GetOffset()
	resultSizeInsufficient := false
	for _, topk := range t.result.Results.Topks {
		if topk < limit {
			resultSizeInsufficient = true
			break
		}
	}
	t.resultSizeInsufficient = resultSizeInsufficient
	t.result.CollectionName = t.collectionName
}

func (t *searchTask) getBM25SearchTexts(placeholder []byte) ([]string, error) {
	pb := &commonpb.PlaceholderGroup{}
	if err := proto.Unmarshal(placeholder, pb); err != nil {
		return nil, merr.WrapErrParameterInvalidMsg("failed to unmarshal placeholder group: %v", err)
	}

	if len(pb.Placeholders) != 1 || len(pb.Placeholders[0].Values) == 0 {
		return nil, merr.WrapErrParameterInvalidMsg("please provide varchar/text for BM25 Function based search")
	}

	holder := pb.Placeholders[0]
	if holder.Type != commonpb.PlaceholderType_VarChar {
		return nil, merr.WrapErrParameterInvalidMsg(fmt.Sprintf("please provide varchar/text for BM25 Function based search, got %s", holder.Type.String()))
	}

	texts := funcutil.GetVarCharFromPlaceholder(holder)
	return texts, nil
}
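// createLexicalHighlighter builds a highlight task for a BM25-based search.
// It resolves the BM25 function's input field as the highlight target, reads
// pre/post tags and the highlight-search-text switch from the highlighter
// params, and records the query texts (plus analyzer names, if any) so the
// QueryNodes can compute highlights.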
func (t *searchTask) createLexicalHighlighter(highlighter *commonpb.Highlighter, metricType string, annsField int64, placeholder []byte, analyzerName string) error {
	task := &highlightTask{
		HighlightTask: &querypb.HighlightTask{},
	}
	params := funcutil.KeyValuePair2Map(highlighter.GetParams())
	if metricType != metric.BM25 {
		return merr.WrapErrParameterInvalidMsg(`Search highlight only support with metric type "BM25" but was: %s`, metricType)
	}

	function, ok := getBM25FunctionOfAnnsField(annsField, t.schema.GetFunctions())
	if !ok {
		return merr.WrapErrServiceInternal(`Search with highlight failed, input field of BM25 annsField not found`)
	}
	task.FieldId = function.InputFieldIds[0]
	task.FieldName = function.InputFieldNames[0]

	if value, ok := params[HighlightSearchTextKey]; ok {
		enable, err := strconv.ParseBool(value)
		if err != nil {
			return merr.WrapErrParameterInvalidMsg("unmarshal highlight_search_data as bool failed: %v", err)
		}
		// only highlighting of the search texts is supported for now,
		// so skip if highlight search is not enabled
		if !enable {
			return nil
		}
	}

	// set pre_tags and post_tags
	if value, ok := params[PreTagsKey]; ok {
		tags := []string{}
		if err := json.Unmarshal([]byte(value), &tags); err != nil {
			return merr.WrapErrParameterInvalidMsg("unmarshal pre_tags as string list failed: %v", err)
		}
		if len(tags) == 0 {
			return merr.WrapErrParameterInvalidMsg("pre_tags cannot be empty list")
		}
		task.preTags = make([][]byte, len(tags))
		for i, tag := range tags {
			task.preTags[i] = []byte(tag)
		}
	} else {
		task.preTags = [][]byte{[]byte(DefaultPreTag)}
	}

	if value, ok := params[PostTagsKey]; ok {
		tags := []string{}
		if err := json.Unmarshal([]byte(value), &tags); err != nil {
			return merr.WrapErrParameterInvalidMsg("unmarshal post_tags as string list failed: %v", err)
		}
		if len(tags) == 0 {
			return merr.WrapErrParameterInvalidMsg("post_tags cannot be empty list")
		}
		task.postTags = make([][]byte, len(tags))
		for i, tag := range tags {
			task.postTags[i] = []byte(tag)
		}
	} else {
		task.postTags = [][]byte{[]byte(DefaultPostTag)}
	}

	// use the BM25 search texts as the query texts
	texts, err := t.getBM25SearchTexts(placeholder)
	if err != nil {
		return err
	}
	task.Texts = texts
	task.SearchTextNum = int64(len(texts))
	if analyzerName != "" {
		task.AnalyzerNames = []string{}
		for i := 0; i < len(texts); i++ {
			task.AnalyzerNames = append(task.AnalyzerNames, analyzerName)
		}
	}
	t.highlightTasks = append(t.highlightTasks, task)
	return nil
}

func (t *searchTask) addHighlightTask(highlighter *commonpb.Highlighter, metricType string, annsField int64, placeholder []byte, analyzerName string) error {
	if highlighter == nil {
		return nil
	}
	switch highlighter.GetType() {
	case commonpb.HighlightType_Lexical:
		return t.createLexicalHighlighter(highlighter, metricType, annsField, placeholder, analyzerName)
	default:
		return merr.WrapErrParameterInvalidMsg("unsupported highlight type: %v", highlighter.GetType())
	}
}
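// initSearchRequest builds the internal search request for a plain (non-hybrid)
// search: it generates the plan from the DSL and search params, resolves
// partition IDs in partition-key mode, decides whether a requery is needed for
// vector output fields, and serializes the plan.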
func (t *searchTask) initSearchRequest(ctx context.Context) error {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "init search request")
	defer sp.End()

	log := log.Ctx(ctx).With(zap.Int64("collID", t.GetCollectionID()), zap.String("collName", t.collectionName))

	plan, queryInfo, offset, isIterator, err := t.tryGeneratePlan(t.request.GetSearchParams(), t.request.GetDsl(), t.request.GetExprTemplateValues())
	if err != nil {
		return err
	}

	if t.request.FunctionScore != nil {
		if t.functionScore, err = rerank.NewFunctionScore(t.schema.CollectionSchema, t.request.FunctionScore, &models.ModelExtraInfo{ClusterID: paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), DBName: t.request.GetDbName()}); err != nil {
			log.Warn("Failed to create function score", zap.Error(err))
			return err
		}
		if !t.functionScore.IsSupportGroup() && queryInfo.GetGroupByFieldId() > 0 {
			return merr.WrapErrParameterInvalidMsg("Rerank %s does not support grouping search", t.functionScore.RerankName())
		}
	}

	t.isIterator = isIterator
	t.SearchRequest.Offset = offset
	t.SearchRequest.FieldId = queryInfo.GetQueryFieldId()

	if t.partitionKeyMode {
		// isolation has a tighter constraint, so check it first
		mvErr := setQueryInfoIfMvEnable(queryInfo, t, plan)
		if mvErr != nil {
			return mvErr
		}
		partitionIDs, err2 := t.tryParsePartitionIDsFromPlan(plan)
		if err2 != nil {
			return err2
		}
		if len(partitionIDs) > 0 {
			t.SearchRequest.PartitionIDs = partitionIDs
		}
	}

	allFields := typeutil.GetAllFieldSchemas(t.schema.CollectionSchema)
	vectorOutputFields := lo.Filter(allFields, func(field *schemapb.FieldSchema, _ int) bool {
		return lo.Contains(t.translatedOutputFields, field.GetName()) && typeutil.IsVectorType(field.GetDataType())
	})
	t.needRequery = len(vectorOutputFields) > 0
	if t.needRequery {
		plan.OutputFieldIds = t.functionScore.GetAllInputFieldIDs()
	} else {
		primaryFieldSchema, err := t.schema.GetPkField()
		if err != nil {
			return err
		}
		allFieldIDs := typeutil.NewSet[int64](t.SearchRequest.OutputFieldsId...)
		allFieldIDs.Insert(t.functionScore.GetAllInputFieldIDs()...)
		allFieldIDs.Insert(primaryFieldSchema.FieldID)
		plan.OutputFieldIds = allFieldIDs.Collect()
		plan.DynamicFields = t.userDynamicFields
	}
	plan.Namespace = t.request.Namespace

	t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan)
	if err != nil {
		return err
	}

	if typeutil.IsFieldSparseFloatVector(t.schema.CollectionSchema, t.SearchRequest.FieldId) {
		metrics.ProxySearchSparseNumNonZeros.WithLabelValues(
			strconv.FormatInt(paramtable.GetNodeID(), 10),
			t.collectionName,
			metrics.SearchLabel,
			strconv.FormatInt(t.SearchRequest.FieldId, 10),
		).Observe(float64(typeutil.EstimateSparseVectorNNZFromPlaceholderGroup(t.request.PlaceholderGroup, int(t.request.GetNq()))))
	}

	t.SearchRequest.PlaceholderGroup = t.request.PlaceholderGroup
	t.SearchRequest.Topk = queryInfo.GetTopk()
	t.SearchRequest.MetricType = queryInfo.GetMetricType()
	t.queryInfos = append(t.queryInfos, queryInfo)
	t.SearchRequest.DslType = commonpb.DslType_BoolExprV1
	t.SearchRequest.GroupByFieldId = queryInfo.GroupByFieldId
	t.SearchRequest.GroupSize = queryInfo.GroupSize

	if t.SearchRequest.MetricType == metric.BM25 {
		analyzer, err := funcutil.GetAttrByKeyFromRepeatedKV(AnalyzerKey, t.request.GetSearchParams())
		if err == nil {
			t.SearchRequest.AnalyzerName = analyzer
		}
	}

	if err := t.addHighlightTask(t.request.GetHighlighter(), t.SearchRequest.MetricType, t.SearchRequest.FieldId, t.request.GetPlaceholderGroup(), t.SearchRequest.GetAnalyzerName()); err != nil {
		return err
	}

	if embedding.HasNonBM25Functions(t.schema.CollectionSchema.Functions, []int64{queryInfo.GetQueryFieldId()}) {
		ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-call-function-udf")
		defer sp.End()
		exec, err := embedding.NewFunctionExecutor(t.schema.CollectionSchema, nil, &models.ModelExtraInfo{ClusterID: paramtable.Get().CommonCfg.ClusterPrefix.GetValue(), DBName: t.request.GetDbName()})
		if err != nil {
			return err
		}
		sp.AddEvent("Create-function-udf")
		if err := exec.ProcessSearch(ctx, t.SearchRequest); err != nil {
			return err
		}
		sp.AddEvent("Call-function-udf")
	}

	log.Debug("proxy init search request",
		zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()),
		zap.Stringer("plan", plan)) // may be very large if a large term is passed

	return nil
}
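// tryGeneratePlan parses the search params and DSL into a search plan. It
// resolves the anns field (defaulting to the single vector field when none is
// specified), validates group-by constraints, and returns the plan together
// with its query info, offset, and whether this search is an iterator.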
func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string, exprTemplateValues map[string]*schemapb.TemplateValue) (*planpb.PlanNode, *planpb.QueryInfo, int64, bool, error) {
	annsFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, params)
	if err != nil || len(annsFieldName) == 0 {
		vecFields := typeutil.GetVectorFieldSchemas(t.schema.CollectionSchema)
		if len(vecFields) == 0 {
			return nil, nil, 0, false, errors.New(AnnsFieldKey + " not found in schema")
		}
		if enableMultipleVectorFields && len(vecFields) > 1 {
			return nil, nil, 0, false, errors.New("multiple anns_fields exist, please specify an anns_field in search_params")
		}
		annsFieldName = vecFields[0].Name
	}

	searchInfo, err := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams)
	if err != nil {
		return nil, nil, 0, false, err
	}
	if searchInfo.collectionID > 0 && searchInfo.collectionID != t.GetCollectionID() {
		return nil, nil, 0, false, merr.WrapErrParameterInvalidMsg("collection id:%d in the request is not consistent with that in the search context,"+
			" the alias or database may have been changed: %d", searchInfo.collectionID, t.GetCollectionID())
	}

	annField := typeutil.GetFieldByName(t.schema.CollectionSchema, annsFieldName)
	if searchInfo.planInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector {
		return nil, nil, 0, false, errors.New("search_group_by operation is not supported on a binary vector column")
	}

	searchInfo.planInfo.QueryFieldId = annField.GetFieldID()
	start := time.Now()
	plan, planErr := planparserv2.CreateSearchPlan(t.schema.schemaHelper, dsl, annsFieldName, searchInfo.planInfo, exprTemplateValues, t.request.GetFunctionScore())
	if planErr != nil {
		log.Ctx(t.ctx).Warn("failed to create query plan", zap.Error(planErr),
			zap.String("dsl", dsl), // may be very large if a large term is passed
			zap.String("anns field", annsFieldName), zap.Any("query info", searchInfo.planInfo))
		metrics.ProxyParseExpressionLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "search", metrics.FailLabel).Observe(float64(time.Since(start).Milliseconds()))
		return nil, nil, 0, false, merr.WrapErrParameterInvalidMsg("failed to create query plan: %v", planErr)
	}
	metrics.ProxyParseExpressionLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), "search", metrics.SuccessLabel).Observe(float64(time.Since(start).Milliseconds()))
	log.Ctx(t.ctx).Debug("create query plan",
		zap.String("dsl", t.request.Dsl), // may be very large if a large term is passed
		zap.String("anns field", annsFieldName), zap.Any("query info", searchInfo.planInfo))
	return plan, searchInfo.planInfo, searchInfo.offset, searchInfo.isIterator, nil
}

func (t *searchTask) tryParsePartitionIDsFromPlan(plan *planpb.PlanNode) ([]int64, error) {
	expr, err := exprutil.ParseExprFromPlan(plan)
	if err != nil {
		log.Ctx(t.ctx).Warn("failed to parse expr", zap.Error(err))
		return nil, err
	}
	partitionKeys := exprutil.ParseKeys(expr, exprutil.PartitionKey)
	hashedPartitionNames, err := assignPartitionKeys(t.ctx, t.request.GetDbName(), t.collectionName, partitionKeys)
	if err != nil {
		log.Ctx(t.ctx).Warn("failed to assign partition keys", zap.Error(err))
		return nil, err
	}

	if len(hashedPartitionNames) > 0 {
		// translate partition names to partition IDs; a regex pattern is used to match partition names
		partitionIDs, err2 := getPartitionIDs(t.ctx, t.request.GetDbName(), t.collectionName, hashedPartitionNames)
		if err2 != nil {
			log.Ctx(t.ctx).Warn("failed to get partition ids", zap.Error(err2))
			return nil, err2
		}
		return partitionIDs, nil
	}
	return nil, nil
}
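// Execute dispatches the search to the shard leaders through the load-balancing
// policy; searchShard is invoked once per channel, and the per-node results are
// gathered into resultBuf for PostExecute to reduce.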
func (t *searchTask) Execute(ctx context.Context) error {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-Execute")
	defer sp.End()
	log := log.Ctx(ctx).WithLazy(zap.Int64("nq", t.SearchRequest.GetNq()))

	tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute search %d", t.ID()))
	defer tr.CtxElapse(ctx, "done")

	err := t.lb.Execute(ctx, shardclient.CollectionWorkLoad{
		Db:             t.request.GetDbName(),
		CollectionID:   t.SearchRequest.CollectionID,
		CollectionName: t.collectionName,
		Nq:             t.Nq,
		Exec:           t.searchShard,
	})
	if err != nil {
		log.Warn("search execute failed", zap.Error(err))
		return errors.Wrap(err, "failed to search")
	}

	log.Debug("Search Execute done.",
		zap.Int64("collection", t.GetCollectionID()),
		zap.Int64s("partitionIDs", t.GetPartitionIDs()))
	return nil
}

// getLastBound finds the last bound based on the reduced results and the metric type.
// It only supports nq == 1 and is used by search iterator v2.
func getLastBound(result *milvuspb.SearchResults, incomingLastBound *float32, metricType string) float32 {
	numScores := len(result.Results.Scores)
	if numScores > 0 && result.GetResults().GetNumQueries() == 1 {
		return result.Results.Scores[numScores-1]
	}
	// if no result was found and the incoming last bound is not nil, return it
	if incomingLastBound != nil {
		return *incomingLastBound
	}
	// if no result was found on the first call, return the closest bound
	if metric.PositivelyRelated(metricType) {
		return math.MaxFloat32
	}
	return -math.MaxFloat32
}
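// PostExecute reduces the per-shard results through the search pipeline,
// restores the pk field when it was explicitly requested, fills iterator
// metadata, validates geometry fields, and converts timestamptz values into
// the resolved timezone.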
func (t *searchTask) PostExecute(ctx context.Context) error {
	ctx, sp := otel.Tracer(typeutil.ProxyRole).Start(ctx, "Proxy-Search-PostExecute")
	defer sp.End()

	tr := timerecord.NewTimeRecorder("searchTask PostExecute")
	defer func() {
		tr.CtxElapse(ctx, "done")
	}()
	log := log.Ctx(ctx).With(zap.Int64("nq", t.SearchRequest.GetNq()))

	toReduceResults, err := t.collectSearchResults(ctx)
	if err != nil {
		log.Warn("failed to collect search results", zap.Error(err))
		return err
	}

	t.queryChannelsTs = make(map[string]uint64)
	t.relatedDataSize = 0
	isTopkReduce := false
	isRecallEvaluation := false
	storageCost := segcore.StorageCost{}
	for _, r := range toReduceResults {
		if r.GetIsTopkReduce() {
			isTopkReduce = true
		}
		if r.GetIsRecallEvaluation() {
			isRecallEvaluation = true
		}
		t.relatedDataSize += r.GetCostAggregation().GetTotalRelatedDataSize()
		for ch, ts := range r.GetChannelsMvcc() {
			t.queryChannelsTs[ch] = ts
		}
		storageCost.ScannedRemoteBytes += r.GetScannedRemoteBytes()
		storageCost.ScannedTotalBytes += r.GetScannedTotalBytes()
	}
	t.isTopkReduce = isTopkReduce
	t.isRecallEvaluation = isRecallEvaluation

	// run the post-process pipeline
	pipeline, err := newSearchPipeline(t)
	if err != nil {
		log.Warn("Failed to create post-process pipeline", zap.Error(err))
		return err
	}
	if t.result, t.storageCost, err = pipeline.Run(ctx, sp, toReduceResults, storageCost); err != nil {
		return err
	}
	t.fillResult()

	t.result.Results.OutputFields = t.userOutputFields
	reconstructStructFieldDataForSearch(t.result, t.schema.CollectionSchema)
	t.result.CollectionName = t.request.GetCollectionName()

	primaryFieldSchema, _ := t.schema.GetPkField()
	if t.userRequestedPkFieldExplicitly {
		t.result.Results.OutputFields = append(t.result.Results.OutputFields, primaryFieldSchema.GetName())
		var scalars *schemapb.ScalarField
		if primaryFieldSchema.GetDataType() == schemapb.DataType_Int64 {
			scalars = &schemapb.ScalarField{
				Data: &schemapb.ScalarField_LongData{
					LongData: t.result.Results.Ids.GetIntId(),
				},
			}
		} else {
			scalars = &schemapb.ScalarField{
				Data: &schemapb.ScalarField_StringData{
					StringData: t.result.Results.Ids.GetStrId(),
				},
			}
		}
		pkFieldData := &schemapb.FieldData{
			FieldName: primaryFieldSchema.GetName(),
			FieldId:   primaryFieldSchema.GetFieldID(),
			Type:      primaryFieldSchema.GetDataType(),
			IsDynamic: false,
			Field: &schemapb.FieldData_Scalars{
				Scalars: scalars,
			},
		}
		t.result.Results.FieldsData = append(t.result.Results.FieldsData, pkFieldData)
	}
	t.result.Results.PrimaryFieldName = primaryFieldSchema.GetName()

	if t.isIterator && len(t.queryInfos) == 1 && t.queryInfos[0] != nil {
		if iterInfo := t.queryInfos[0].GetSearchIteratorV2Info(); iterInfo != nil {
			t.result.Results.SearchIteratorV2Results = &schemapb.SearchIteratorV2Results{
				Token:     iterInfo.GetToken(),
				LastBound: getLastBound(t.result, iterInfo.LastBound, getMetricType(toReduceResults)),
			}
		}
	}

	fieldsData := t.result.GetResults().GetFieldsData()
	for i, fieldData := range fieldsData {
		if fieldData.Type == schemapb.DataType_Geometry {
			if err := validateGeometryFieldSearchResult(&fieldsData[i]); err != nil {
				log.Warn("fail to validate geometry field search result", zap.Error(err))
				return err
			}
		}
	}
	if t.result.GetResults().GetGroupByFieldValue() != nil && t.result.GetResults().GetGroupByFieldValue().GetType() == schemapb.DataType_Geometry {
		if err := validateGeometryFieldSearchResult(&t.result.Results.GroupByFieldValue); err != nil {
			log.Warn("fail to validate geometry field search result", zap.Error(err))
			return err
		}
	}

	if t.isIterator && t.request.GetGuaranteeTimestamp() == 0 {
		// first page of the iteration: set up the sessionTs for the iterator
		t.result.SessionTs = getMaxMvccTsFromChannels(t.queryChannelsTs, t.BeginTs())
	}
	metrics.ProxyReduceResultLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.SearchLabel).Observe(float64(tr.RecordSpan().Milliseconds()))

	timeFields := parseTimeFields(t.request.SearchParams)
	if timeFields != nil {
		log.Debug("extracting fields for timestamptz", zap.Strings("fields", timeFields))
		err = extractFieldsFromResults(t.result.GetResults().GetFieldsData(), t.resolvedTimezoneStr, timeFields)
		if err != nil {
			log.Warn("fail to extract fields for timestamptz", zap.Error(err))
			return err
		}
	} else {
		err = timestamptzUTC2IsoStr(t.result.GetResults().GetFieldsData(), t.resolvedTimezoneStr)
		if err != nil {
			log.Warn("fail to translate timestamp", zap.Error(err))
			return err
		}
	}

	log.Debug("Search post execute done",
		zap.Int64("collection", t.GetCollectionID()),
		zap.Int64s("partitionIDs", t.GetPartitionIDs()))
	return nil
}
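// searchShard sends the search request to one QueryNode for a single DML
// channel. On a transport error or a NotShardLeader response, the shard-leader
// cache is deprecated so the next attempt re-resolves the leader.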
func (t *searchTask) searchShard(ctx context.Context, nodeID int64, qn types.QueryNodeClient, channel string) error {
	searchReq := typeutil.Clone(t.SearchRequest)
	searchReq.GetBase().TargetID = nodeID
	req := &querypb.SearchRequest{
		Req:             searchReq,
		DmlChannels:     []string{channel},
		Scope:           querypb.DataScope_All,
		TotalChannelNum: int32(1),
	}

	log := log.Ctx(ctx).With(zap.Int64("collection", t.GetCollectionID()),
		zap.Int64s("partitionIDs", t.GetPartitionIDs()),
		zap.Int64("nodeID", nodeID),
		zap.String("channel", channel))

	var result *internalpb.SearchResults
	var err error

	result, err = qn.Search(ctx, req)
	if err != nil {
		log.Warn("QueryNode search return error", zap.Error(err))
		// globalMetaCache.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
		t.shardClientMgr.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
		return err
	}
	if result.GetStatus().GetErrorCode() == commonpb.ErrorCode_NotShardLeader {
		log.Warn("QueryNode is not shardLeader")
		t.shardClientMgr.DeprecateShardCache(t.request.GetDbName(), t.collectionName)
		return merr.Error(result.GetStatus())
	}
	if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
		log.Warn("QueryNode search result error", zap.String("reason", result.GetStatus().GetReason()))
		return errors.Wrapf(merr.Error(result.GetStatus()), "fail to search on QueryNode %d", nodeID)
	}
	if t.resultBuf != nil {
		t.resultBuf.Insert(result)
	}
	t.lb.UpdateCostMetrics(nodeID, result.CostAggregation)

	return nil
}

func (t *searchTask) estimateResultSize(nq int64, topK int64) (int64, error) {
	vectorOutputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
		return lo.Contains(t.translatedOutputFields, field.GetName()) && typeutil.IsVectorType(field.GetDataType())
	})
	for _, structArrayField := range t.schema.GetStructArrayFields() {
		for _, field := range structArrayField.GetFields() {
			if lo.Contains(t.translatedOutputFields, field.GetName()) && typeutil.IsVectorType(field.GetDataType()) {
				vectorOutputFields = append(vectorOutputFields, field)
			}
		}
	}
	// Currently, we get vectors by requery. Once we support getting vectors from search,
	// searches with a small result size would no longer need requery.
	if len(vectorOutputFields) > 0 {
		return math.MaxInt64, nil
	}
	// If no vector field is requested as output, there is no need to requery.
	return 0, nil

	//outputFields := lo.Filter(t.schema.GetFields(), func(field *schemapb.FieldSchema, _ int) bool {
	//	return lo.Contains(t.translatedOutputFields, field.GetName())
	//})
	//sizePerRecord, err := typeutil.EstimateSizePerRecord(&schemapb.CollectionSchema{Fields: outputFields})
	//if err != nil {
	//	return 0, err
	//}
	//return int64(sizePerRecord) * nq * topK, nil
}

func (t *searchTask) collectSearchResults(ctx context.Context) ([]*internalpb.SearchResults, error) {
	select {
	case <-t.TraceCtx().Done():
		log.Ctx(ctx).Warn("search task wait to finish timeout!")
		return nil, fmt.Errorf("search task wait to finish timeout, msgID=%d", t.ID())
	default:
		toReduceResults := make([]*internalpb.SearchResults, 0)
		log.Ctx(ctx).Debug("all searches are finished or canceled")
		t.resultBuf.Range(func(res *internalpb.SearchResults) bool {
			toReduceResults = append(toReduceResults, res)
			log.Ctx(ctx).Debug("proxy receives one search result",
				zap.Int64("sourceID", res.GetBase().GetSourceID()))
			return true
		})
		return toReduceResults, nil
	}
}

func (t *searchTask) TraceCtx() context.Context {
	return t.ctx
}

func (t *searchTask) ID() UniqueID {
	return t.Base.MsgID
}

func (t *searchTask) SetID(uid UniqueID) {
	t.Base.MsgID = uid
}

func (t *searchTask) Name() string {
	return SearchTaskName
}

func (t *searchTask) Type() commonpb.MsgType {
	return t.Base.MsgType
}

func (t *searchTask) BeginTs() Timestamp {
	return t.Base.Timestamp
}

func (t *searchTask) EndTs() Timestamp {
	return t.Base.Timestamp
}

func (t *searchTask) SetTs(ts Timestamp) {
	t.Base.Timestamp = ts
}

func (t *searchTask) OnEnqueue() error {
	t.Base = commonpbutil.NewMsgBase()
	t.Base.MsgType = commonpb.MsgType_Search
	t.Base.SourceID = paramtable.GetNodeID()
	return nil
}