mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
fix: fix requery without partitionIDs in hybrid search (#30444)
issue: #30412 Signed-off-by: xige-16 <xi.ge@zilliz.com> Signed-off-by: xige-16 <xi.ge@zilliz.com>
This commit is contained in:
parent
dcdf85977c
commit
0a78b38bb8
@ -60,7 +60,6 @@ func initSearchRequest(ctx context.Context, t *searchTask) error {
|
|||||||
}
|
}
|
||||||
t.SearchRequest.OutputFieldsId = outputFieldIDs
|
t.SearchRequest.OutputFieldsId = outputFieldIDs
|
||||||
|
|
||||||
partitionNames := t.request.GetPartitionNames()
|
|
||||||
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
|
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
|
||||||
annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
|
annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
|
||||||
if err != nil || len(annsField) == 0 {
|
if err != nil || len(annsField) == 0 {
|
||||||
@ -109,7 +108,14 @@ func initSearchRequest(ctx context.Context, t *searchTask) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
partitionNames = append(partitionNames, hashedPartitionNames...)
|
if len(hashedPartitionNames) > 0 {
|
||||||
|
// translate partition name to partition ids. Use regex-pattern to match partition name.
|
||||||
|
t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.collectionName, hashedPartitionNames)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("failed to get partition ids", zap.Error(err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
plan.OutputFieldIds = outputFieldIDs
|
plan.OutputFieldIds = outputFieldIDs
|
||||||
@ -138,13 +144,6 @@ func initSearchRequest(ctx context.Context, t *searchTask) error {
|
|||||||
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
zap.Stringer("plan", plan)) // may be very large if large term passed.
|
||||||
}
|
}
|
||||||
|
|
||||||
// translate partition name to partition ids. Use regex-pattern to match partition name.
|
|
||||||
t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.collectionName, partitionNames)
|
|
||||||
if err != nil {
|
|
||||||
log.Warn("failed to get partition ids", zap.Error(err))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if deadline, ok := t.TraceCtx().Deadline(); ok {
|
if deadline, ok := t.TraceCtx().Deadline(); ok {
|
||||||
t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
|
t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -42,6 +42,7 @@ type hybridSearchTask struct {
|
|||||||
tr *timerecord.TimeRecorder
|
tr *timerecord.TimeRecorder
|
||||||
schema *schemaInfo
|
schema *schemaInfo
|
||||||
requery bool
|
requery bool
|
||||||
|
partitionKeyMode bool
|
||||||
|
|
||||||
userOutputFields []string
|
userOutputFields []string
|
||||||
|
|
||||||
@ -51,6 +52,8 @@ type hybridSearchTask struct {
|
|||||||
|
|
||||||
resultBuf *typeutil.ConcurrentSet[*querypb.HybridSearchResult]
|
resultBuf *typeutil.ConcurrentSet[*querypb.HybridSearchResult]
|
||||||
multipleRecallResults *typeutil.ConcurrentSet[*milvuspb.SearchResults]
|
multipleRecallResults *typeutil.ConcurrentSet[*milvuspb.SearchResults]
|
||||||
|
partitionIDsSet *typeutil.ConcurrentSet[UniqueID]
|
||||||
|
|
||||||
reScorers []reScorer
|
reScorers []reScorer
|
||||||
queryChannelsTs map[string]Timestamp
|
queryChannelsTs map[string]Timestamp
|
||||||
rankParams *rankParams
|
rankParams *rankParams
|
||||||
@ -97,14 +100,26 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
partitionKeyMode, err := isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
|
t.partitionKeyMode, err = isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn("is partition key mode failed", zap.Error(err))
|
log.Warn("is partition key mode failed", zap.Error(err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if partitionKeyMode && len(t.request.GetPartitionNames()) != 0 {
|
if t.partitionKeyMode {
|
||||||
|
if len(t.request.GetPartitionNames()) != 0 {
|
||||||
return errors.New("not support manually specifying the partition names if partition key mode is used")
|
return errors.New("not support manually specifying the partition names if partition key mode is used")
|
||||||
}
|
}
|
||||||
|
t.partitionIDsSet = typeutil.NewConcurrentSet[UniqueID]()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !t.partitionKeyMode && len(t.request.GetPartitionNames()) > 0 {
|
||||||
|
// translate partition name to partition ids. Use regex-pattern to match partition name.
|
||||||
|
t.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, t.request.GetPartitionNames())
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("failed to get partition ids", zap.Error(err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, t.schema, false)
|
t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, t.schema, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -176,6 +191,7 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
|
|||||||
ReqID: paramtable.GetNodeID(),
|
ReqID: paramtable.GetNodeID(),
|
||||||
DbID: 0, // todo
|
DbID: 0, // todo
|
||||||
CollectionID: collID,
|
CollectionID: collID,
|
||||||
|
PartitionIDs: t.GetPartitionIDs(),
|
||||||
},
|
},
|
||||||
request: searchReq,
|
request: searchReq,
|
||||||
schema: t.schema,
|
schema: t.schema,
|
||||||
@ -184,7 +200,7 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
|
|||||||
node: t.node,
|
node: t.node,
|
||||||
lb: t.lb,
|
lb: t.lb,
|
||||||
|
|
||||||
partitionKeyMode: partitionKeyMode,
|
partitionKeyMode: t.partitionKeyMode,
|
||||||
resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](),
|
resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](),
|
||||||
}
|
}
|
||||||
err := initSearchRequest(ctx, t.searchTasks[index])
|
err := initSearchRequest(ctx, t.searchTasks[index])
|
||||||
@ -192,6 +208,9 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
|
|||||||
log.Debug("init hybrid search request failed", zap.Error(err))
|
log.Debug("init hybrid search request failed", zap.Error(err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if t.partitionKeyMode {
|
||||||
|
t.partitionIDsSet.Upsert(t.searchTasks[index].GetPartitionIDs()...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug("hybrid search preExecute done.",
|
log.Debug("hybrid search preExecute done.",
|
||||||
@ -208,6 +227,9 @@ func (t *hybridSearchTask) hybridSearchShard(ctx context.Context, nodeID int64,
|
|||||||
}
|
}
|
||||||
hybridSearchReq := typeutil.Clone(t.HybridSearchRequest)
|
hybridSearchReq := typeutil.Clone(t.HybridSearchRequest)
|
||||||
hybridSearchReq.GetBase().TargetID = nodeID
|
hybridSearchReq.GetBase().TargetID = nodeID
|
||||||
|
if t.partitionKeyMode {
|
||||||
|
t.PartitionIDs = t.partitionIDsSet.Collect()
|
||||||
|
}
|
||||||
req := &querypb.HybridSearchRequest{
|
req := &querypb.HybridSearchRequest{
|
||||||
Req: hybridSearchReq,
|
Req: hybridSearchReq,
|
||||||
DmlChannels: []string{channel},
|
DmlChannels: []string{channel},
|
||||||
@ -437,8 +459,7 @@ func (t *hybridSearchTask) Requery() error {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO:silverxia move partitionIDs to hybrid search level
|
return doRequery(t.ctx, t.CollectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, t.GetPartitionIDs())
|
||||||
return doRequery(t.ctx, t.CollectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, []int64{})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func rankSearchResultData(ctx context.Context,
|
func rankSearchResultData(ctx context.Context,
|
||||||
|
|||||||
@ -268,6 +268,15 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||||||
return errors.New("not support manually specifying the partition names if partition key mode is used")
|
return errors.New("not support manually specifying the partition names if partition key mode is used")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if !t.partitionKeyMode && len(t.request.GetPartitionNames()) > 0 {
|
||||||
|
// translate partition name to partition ids. Use regex-pattern to match partition name.
|
||||||
|
t.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, t.request.GetPartitionNames())
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("failed to get partition ids", zap.Error(err))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, t.schema, false)
|
t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, t.schema, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn("translate output fields failed", zap.Error(err))
|
log.Warn("translate output fields failed", zap.Error(err))
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user