fix: fix requery without partitionIDs in hybrid search (#30444)

issue: #30412 
Signed-off-by: xige-16 <xi.ge@zilliz.com>

Signed-off-by: xige-16 <xi.ge@zilliz.com>
This commit is contained in:
xige-16 2024-02-02 16:47:13 +08:00 committed by GitHub
parent dcdf85977c
commit 0a78b38bb8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 50 additions and 21 deletions

View File

@ -60,7 +60,6 @@ func initSearchRequest(ctx context.Context, t *searchTask) error {
} }
t.SearchRequest.OutputFieldsId = outputFieldIDs t.SearchRequest.OutputFieldsId = outputFieldIDs
partitionNames := t.request.GetPartitionNames()
if t.request.GetDslType() == commonpb.DslType_BoolExprV1 { if t.request.GetDslType() == commonpb.DslType_BoolExprV1 {
annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams()) annsField, err := funcutil.GetAttrByKeyFromRepeatedKV(AnnsFieldKey, t.request.GetSearchParams())
if err != nil || len(annsField) == 0 { if err != nil || len(annsField) == 0 {
@ -109,7 +108,14 @@ func initSearchRequest(ctx context.Context, t *searchTask) error {
return err return err
} }
partitionNames = append(partitionNames, hashedPartitionNames...) if len(hashedPartitionNames) > 0 {
// translate partition name to partition ids. Use regex-pattern to match partition name.
t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.collectionName, hashedPartitionNames)
if err != nil {
log.Warn("failed to get partition ids", zap.Error(err))
return err
}
}
} }
plan.OutputFieldIds = outputFieldIDs plan.OutputFieldIds = outputFieldIDs
@ -138,13 +144,6 @@ func initSearchRequest(ctx context.Context, t *searchTask) error {
zap.Stringer("plan", plan)) // may be very large if large term passed. zap.Stringer("plan", plan)) // may be very large if large term passed.
} }
// translate partition name to partition ids. Use regex-pattern to match partition name.
t.SearchRequest.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), t.collectionName, partitionNames)
if err != nil {
log.Warn("failed to get partition ids", zap.Error(err))
return err
}
if deadline, ok := t.TraceCtx().Deadline(); ok { if deadline, ok := t.TraceCtx().Deadline(); ok {
t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0) t.SearchRequest.TimeoutTimestamp = tsoutil.ComposeTSByTime(deadline, 0)
} }

View File

@ -42,6 +42,7 @@ type hybridSearchTask struct {
tr *timerecord.TimeRecorder tr *timerecord.TimeRecorder
schema *schemaInfo schema *schemaInfo
requery bool requery bool
partitionKeyMode bool
userOutputFields []string userOutputFields []string
@ -51,6 +52,8 @@ type hybridSearchTask struct {
resultBuf *typeutil.ConcurrentSet[*querypb.HybridSearchResult] resultBuf *typeutil.ConcurrentSet[*querypb.HybridSearchResult]
multipleRecallResults *typeutil.ConcurrentSet[*milvuspb.SearchResults] multipleRecallResults *typeutil.ConcurrentSet[*milvuspb.SearchResults]
partitionIDsSet *typeutil.ConcurrentSet[UniqueID]
reScorers []reScorer reScorers []reScorer
queryChannelsTs map[string]Timestamp queryChannelsTs map[string]Timestamp
rankParams *rankParams rankParams *rankParams
@ -97,14 +100,26 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
return err return err
} }
partitionKeyMode, err := isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName) t.partitionKeyMode, err = isPartitionKeyMode(ctx, t.request.GetDbName(), collectionName)
if err != nil { if err != nil {
log.Warn("is partition key mode failed", zap.Error(err)) log.Warn("is partition key mode failed", zap.Error(err))
return err return err
} }
if partitionKeyMode && len(t.request.GetPartitionNames()) != 0 { if t.partitionKeyMode {
if len(t.request.GetPartitionNames()) != 0 {
return errors.New("not support manually specifying the partition names if partition key mode is used") return errors.New("not support manually specifying the partition names if partition key mode is used")
} }
t.partitionIDsSet = typeutil.NewConcurrentSet[UniqueID]()
}
if !t.partitionKeyMode && len(t.request.GetPartitionNames()) > 0 {
// translate partition name to partition ids. Use regex-pattern to match partition name.
t.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, t.request.GetPartitionNames())
if err != nil {
log.Warn("failed to get partition ids", zap.Error(err))
return err
}
}
t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, t.schema, false) t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, t.schema, false)
if err != nil { if err != nil {
@ -176,6 +191,7 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
ReqID: paramtable.GetNodeID(), ReqID: paramtable.GetNodeID(),
DbID: 0, // todo DbID: 0, // todo
CollectionID: collID, CollectionID: collID,
PartitionIDs: t.GetPartitionIDs(),
}, },
request: searchReq, request: searchReq,
schema: t.schema, schema: t.schema,
@ -184,7 +200,7 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
node: t.node, node: t.node,
lb: t.lb, lb: t.lb,
partitionKeyMode: partitionKeyMode, partitionKeyMode: t.partitionKeyMode,
resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](), resultBuf: typeutil.NewConcurrentSet[*internalpb.SearchResults](),
} }
err := initSearchRequest(ctx, t.searchTasks[index]) err := initSearchRequest(ctx, t.searchTasks[index])
@ -192,6 +208,9 @@ func (t *hybridSearchTask) PreExecute(ctx context.Context) error {
log.Debug("init hybrid search request failed", zap.Error(err)) log.Debug("init hybrid search request failed", zap.Error(err))
return err return err
} }
if t.partitionKeyMode {
t.partitionIDsSet.Upsert(t.searchTasks[index].GetPartitionIDs()...)
}
} }
log.Debug("hybrid search preExecute done.", log.Debug("hybrid search preExecute done.",
@ -208,6 +227,9 @@ func (t *hybridSearchTask) hybridSearchShard(ctx context.Context, nodeID int64,
} }
hybridSearchReq := typeutil.Clone(t.HybridSearchRequest) hybridSearchReq := typeutil.Clone(t.HybridSearchRequest)
hybridSearchReq.GetBase().TargetID = nodeID hybridSearchReq.GetBase().TargetID = nodeID
if t.partitionKeyMode {
t.PartitionIDs = t.partitionIDsSet.Collect()
}
req := &querypb.HybridSearchRequest{ req := &querypb.HybridSearchRequest{
Req: hybridSearchReq, Req: hybridSearchReq,
DmlChannels: []string{channel}, DmlChannels: []string{channel},
@ -437,8 +459,7 @@ func (t *hybridSearchTask) Requery() error {
}, },
} }
// TODO:silverxia move partitionIDs to hybrid search level return doRequery(t.ctx, t.CollectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, t.GetPartitionIDs())
return doRequery(t.ctx, t.CollectionID, t.node, t.schema.CollectionSchema, queryReq, t.result, t.queryChannelsTs, []int64{})
} }
func rankSearchResultData(ctx context.Context, func rankSearchResultData(ctx context.Context,

View File

@ -268,6 +268,15 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
return errors.New("not support manually specifying the partition names if partition key mode is used") return errors.New("not support manually specifying the partition names if partition key mode is used")
} }
if !t.partitionKeyMode && len(t.request.GetPartitionNames()) > 0 {
// translate partition name to partition ids. Use regex-pattern to match partition name.
t.PartitionIDs, err = getPartitionIDs(ctx, t.request.GetDbName(), collectionName, t.request.GetPartitionNames())
if err != nil {
log.Warn("failed to get partition ids", zap.Error(err))
return err
}
}
t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, t.schema, false) t.request.OutputFields, t.userOutputFields, err = translateOutputFields(t.request.OutputFields, t.schema, false)
if err != nil { if err != nil {
log.Warn("translate output fields failed", zap.Error(err)) log.Warn("translate output fields failed", zap.Error(err))