diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go index b247f7ee18..db5b16e1de 100644 --- a/internal/proxy/search_util.go +++ b/internal/proxy/search_util.go @@ -60,7 +60,6 @@ func initSearchRequest(ctx context.Context, t *searchTask) error { } t.SearchRequest.OutputFieldsId = outputFieldIDs - partitionNames := t.request.GetPartitionNames() if t.request.GetDslType() == commonpb.DslType_BoolExprV1 { annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams()) if err != nil || len(annsField) == 0 { @@ -109,7 +108,14 @@ func initSearchRequest(ctx context.Context, t *searchTask) error { return err } - partitionNames = append(partitionNames, hashedPartitionNames...) + if len(hashedPartitionNames) > 0 { + // translate partition name to partition ids. Use regex-pattern to match partition name. + t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.collectionName, hashedPartitionNames) + if err != nil { + log.Warn("failed to get partition ids", zap.Error(err)) + return err + } + } } plan.OutputFieldIds = outputFieldIDs @@ -138,13 +144,6 @@ func initSearchRequest(ctx context.Context, t *searchTask) error { zap.Stringer("plan", plan)) // may be very large if large term passed. } - // translate partition name to partition ids. Use regex-pattern to match partition name. - t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.collectionName, partitionNames) - if err != nil { - log.Warn("failed to get partition ids", zap.Error(err)) - return err - } - if deadline, ok := t.TraceCtx().Deadline(); ok { t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0) } diff --git a/internal/proxy/task_hybrid_search.go b/internal/proxy/task_hybrid_search.go index 8bb98c71c7..33ec310e85 100644 --- a/internal/proxy/task_hybrid_search.go +++ b/internal/proxy/task_hybrid_search.go @@ -39,9 +39,10 @@ type hybridSearchTask struct { request *milvuspb.HybridSearchRequest searchTasks []*searchTask - tr *timerecord.TimeRecorder - schema *schemaInfo - requery bool + tr *timerecord.TimeRecorder + schema *schemaInfo + requery bool + partitionKeyMode bool userOutputFields []string @@ -51,9 +52,11 @@ type hybridSearchTask struct { resultBuf *typeutil.ConcurrentSet[*querypb.HybridSearchResult] multipleRecallResults *typeutil.ConcurrentSet[*milvuspb.SearchResults] - reScorers []reScorer - queryChannelsTs map[string]Timestamp - rankParams *rankParams + partitionIDsSet *typeutil.ConcurrentSet[UniqueID] + + reScorers []reScorer + queryChannelsTs map[string]Timestamp + rankParams *rankParams } func (t *hybridSearchTask) PreExecute(ctx context.Context) error { @@ -97,13 +100,25 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error { return err } - partitionKeyMode, err := isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName) + t.partitionKeyMode, err = isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName) if err != nil { log.Warn("is partition key mode failed", zap.Error(err)) return err } - if partitionKeyMode && len(t.request.GetPartitionNames()) != 0 { - return errors.New("not support manually specifying the partition names if partition key mode is used") + if t.partitionKeyMode { + if len(t.request.GetPartitionNames()) != 0 { + return errors.New("not support manually specifying the partition names if partition key mode is used") + } + t.partitionIDsSet = typeutil.NewConcurrentSet[UniqueID]() + } + + if !t.partitionKeyMode && len(t.request.GetPartitionNames()) > 0 { + // translate partition name to partition ids. Use regex-pattern to match partition name. + t.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, t.request.GetPartitionNames()) + if err != nil { + log.Warn("failed to get partition ids", zap.Error(err)) + return err + } } t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, t.schema, false) @@ -176,6 +191,7 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error { ReqID: paramtable.GetNodeID(), DbID: 0, // todo CollectionID: collID, + PartitionIDs: t.GetPartitionIDs(), }, request: searchReq, schema: t.schema, @@ -184,7 +200,7 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error { node: t.node, lb: t.lb, - partitionKeyMode: partitionKeyMode, + partitionKeyMode: t.partitionKeyMode, resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](), } err := initSearchRequest(ctx, t.searchTasks[index]) @@ -192,6 +208,9 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error { log.Debug("init hybrid search request failed", zap.Error(err)) return err } + if t.partitionKeyMode { + t.partitionIDsSet.Upsert(t.searchTasks[index].GetPartitionIDs()...) + } } log.Debug("hybrid search preExecute done.", @@ -208,6 +227,9 @@ func (t *hybridSearchTask) hybridSearchShard(ctx context.Context, nodeID int64, } hybridSearchReq := typeutil.Clone(t.HybridSearchRequest) hybridSearchReq.GetBase().TargetID = nodeID + if t.partitionKeyMode { + t.PartitionIDs = t.partitionIDsSet.Collect() + } req := &querypb.HybridSearchRequest{ Req: hybridSearchReq, DmlChannels: []string{channel}, @@ -437,8 +459,7 @@ func (t *hybridSearchTask) Requery() error { }, } - // TODO:silverxia move partitionIDs to hybrid search level - return doRequery(t.ctx, t.CollectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, []int64{}) + return doRequery(t.ctx, t.CollectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, t.GetPartitionIDs()) } func rankSearchResultData(ctx context.Context, diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 5af8a19d79..b888b198cc 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -268,6 +268,15 @@ func (t *searchTask) PreExecute(ctx context.Context) error { return errors.New("not support manually specifying the partition names if partition key mode is used") } + if !t.partitionKeyMode && len(t.request.GetPartitionNames()) > 0 { + // translate partition name to partition ids. Use regex-pattern to match partition name. + t.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, t.request.GetPartitionNames()) + if err != nil { + log.Warn("failed to get partition ids", zap.Error(err)) + return err + } + } + t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, t.schema, false) if err != nil { log.Warn("translate output fields failed", zap.Error(err))