feat: support search highlight with queries (#45736)
Previously, search with highlight only supported using the BM25 search text as the highlight target. This PR adds support for highlighting with user-defined queries.

Related: https://github.com/milvus-io/milvus/issues/42589

---------

Signed-off-by: aoiasd <zhicheng.yue@zilliz.com>
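As a quick illustration (not part of this commit): based on the parameter parsing in NewLexicalHighlighter below, a client-side highlighter params payload could look like the following sketch. The "queries" value is a JSON array of objects with "text", "type", and "field" keys, where "type" names a HighlightQueryType such as "TextMatch". All concrete values here are made up.

// Sketch only: a params map accepted by the parsing code added below.
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	queries := []map[string]any{
		{"text": "vector database", "type": "TextMatch", "field": "document"},
	}
	raw, err := json.Marshal(queries)
	if err != nil {
		panic(err)
	}
	params := map[string]string{
		"queries":               string(raw),
		"pre_tags":              `["<em>"]`,
		"post_tags":             `["</em>"]`,
		"fragment_size":         "100",
		"num_of_fragments":      "5",
		"highlight_search_text": "true",
	}
	fmt.Println(params)
}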
This commit is contained in:
parent b886b14291
commit 7d19c40e3c

internal/proxy/highlighter.go (new file, 394 lines)
@@ -0,0 +1,394 @@
package proxy

import (
	"context"
	"encoding/json"
	"strconv"
	"strings"

	"github.com/cockroachdb/errors"
	"github.com/samber/lo"
	"go.opentelemetry.io/otel/trace"

	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
	"github.com/milvus-io/milvus/internal/proxy/shardclient"
	"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
	"github.com/milvus-io/milvus/pkg/v2/util/funcutil"
	"github.com/milvus-io/milvus/pkg/v2/util/merr"
	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
)

const (
	PreTagsKey             = "pre_tags"
	PostTagsKey            = "post_tags"
	HighlightSearchTextKey = "highlight_search_text"
	HighlightQueryKey      = "queries"
	FragmentOffsetKey      = "fragment_offset"
	FragmentSizeKey        = "fragment_size"
	FragmentNumKey         = "num_of_fragments"
	DefaultFragmentSize    = 100
	DefaultFragmentNum     = 5
	DefaultPreTag          = "<em>"
	DefaultPostTag         = "</em>"
)

type Highlighter interface {
	AsSearchPipelineOperator(t *searchTask) (operator, error)
}

// highlightTask is the highlight task for one field.
type highlightTask struct {
	*querypb.HighlightTask
	preTags  [][]byte
	postTags [][]byte
}

type highlightQuery struct {
	text          string
	fieldName     string
	highlightType querypb.HighlightQueryType
}

type LexicalHighlighter struct {
	tasks map[int64]*highlightTask // fieldID -> highlightTask
	// options shared by all highlight tasks
	// TODO: support setting options per task
	preTags         [][]byte
	postTags        [][]byte
	highlightSearch bool
	options         *querypb.HighlightOptions
	queries         []*highlightQuery
}

// addTaskWithSearchText adds a highlight task driven by the BM25 search text.
// It must be called before addTaskWithQuery.
func (h *LexicalHighlighter) addTaskWithSearchText(fieldID int64, fieldName string, analyzerName string, texts []string) error {
	_, ok := h.tasks[fieldID]
	if ok {
		return merr.WrapErrParameterInvalidMsg("not support hybrid search with highlight now. fieldID: %d", fieldID)
	}

	task := &highlightTask{
		preTags:  h.preTags,
		postTags: h.postTags,
		HighlightTask: &querypb.HighlightTask{
			FieldName: fieldName,
			FieldId:   fieldID,
			Options:   h.options,
		},
	}
	h.tasks[fieldID] = task

	task.Texts = texts
	task.SearchTextNum = int64(len(texts))
	if analyzerName != "" {
		task.AnalyzerNames = []string{}
		for i := 0; i < len(texts); i++ {
			task.AnalyzerNames = append(task.AnalyzerNames, analyzerName)
		}
	}
	return nil
}

func (h *LexicalHighlighter) addTaskWithQuery(fieldID int64, query *highlightQuery) {
	task, ok := h.tasks[fieldID]
	if !ok {
		task = &highlightTask{
			HighlightTask: &querypb.HighlightTask{
				Texts:     []string{},
				FieldId:   fieldID,
				FieldName: query.fieldName,
				Options:   h.options,
			},
			preTags:  h.preTags,
			postTags: h.postTags,
		}
		h.tasks[fieldID] = task
	}

	task.Texts = append(task.Texts, query.text)
	task.Queries = append(task.Queries, &querypb.HighlightQuery{
		Type: query.highlightType,
	})
}

func (h *LexicalHighlighter) AsSearchPipelineOperator(t *searchTask) (operator, error) {
	// add user-defined queries to the highlight tasks
	for _, query := range h.queries {
		fieldID, ok := t.schema.MapFieldID(query.fieldName)
		if !ok {
			return nil, merr.WrapErrParameterInvalidMsg("highlight field not found in schema: %s", query.fieldName)
		}
		h.addTaskWithQuery(fieldID, query)
	}
	return newLexicalHighlightOperator(t, lo.Values(h.tasks))
}

func NewLexicalHighlighter(highlighter *commonpb.Highlighter) (*LexicalHighlighter, error) {
	params := funcutil.KeyValuePair2Map(highlighter.GetParams())
	h := &LexicalHighlighter{
		tasks:   make(map[int64]*highlightTask),
		options: &querypb.HighlightOptions{},
	}

	// set pre_tags and post_tags
	if value, ok := params[PreTagsKey]; ok {
		tags := []string{}
		if err := json.Unmarshal([]byte(value), &tags); err != nil {
			return nil, merr.WrapErrParameterInvalidMsg("unmarshal pre_tags as string array failed: %v", err)
		}
		if len(tags) == 0 {
			return nil, merr.WrapErrParameterInvalidMsg("pre_tags cannot be empty list")
		}

		h.preTags = make([][]byte, len(tags))
		for i, tag := range tags {
			h.preTags[i] = []byte(tag)
		}
	} else {
		h.preTags = [][]byte{[]byte(DefaultPreTag)}
	}

	if value, ok := params[PostTagsKey]; ok {
		tags := []string{}
		if err := json.Unmarshal([]byte(value), &tags); err != nil {
			return nil, merr.WrapErrParameterInvalidMsg("unmarshal post_tags as string list failed: %v", err)
		}
		if len(tags) == 0 {
			return nil, merr.WrapErrParameterInvalidMsg("post_tags cannot be empty list")
		}
		h.postTags = make([][]byte, len(tags))
		for i, tag := range tags {
			h.postTags[i] = []byte(tag)
		}
	} else {
		h.postTags = [][]byte{[]byte(DefaultPostTag)}
	}

	// set fragment config
	if value, ok := params[FragmentSizeKey]; ok {
		fragmentSize, err := strconv.ParseInt(value, 10, 64)
		if err != nil || fragmentSize <= 0 {
			return nil, merr.WrapErrParameterInvalidMsg("invalid fragment_size: %s", value)
		}
		h.options.FragmentSize = fragmentSize
	} else {
		h.options.FragmentSize = DefaultFragmentSize
	}

	if value, ok := params[FragmentNumKey]; ok {
		fragmentNum, err := strconv.ParseInt(value, 10, 64)
		if err != nil || fragmentNum < 0 {
			return nil, merr.WrapErrParameterInvalidMsg("invalid num_of_fragments: %s", value)
		}
		h.options.NumOfFragments = fragmentNum
	} else {
		h.options.NumOfFragments = DefaultFragmentNum
	}

	if value, ok := params[FragmentOffsetKey]; ok {
		fragmentOffset, err := strconv.ParseInt(value, 10, 64)
		if err != nil || fragmentOffset < 0 {
			return nil, merr.WrapErrParameterInvalidMsg("invalid fragment_offset: %s", value)
		}
		h.options.FragmentOffset = fragmentOffset
	}

	if value, ok := params[HighlightSearchTextKey]; ok {
		enable, err := strconv.ParseBool(value)
		if err != nil {
			return nil, merr.WrapErrParameterInvalidMsg("unmarshal highlight_search_text as bool failed: %v", err)
		}

		h.highlightSearch = enable
	}

	if value, ok := params[HighlightQueryKey]; ok {
		queries := []any{}
		if err := json.Unmarshal([]byte(value), &queries); err != nil {
			return nil, merr.WrapErrParameterInvalidMsg("unmarshal highlight queries as json array failed: %v", err)
		}

		for _, query := range queries {
			m, ok := query.(map[string]any)
			if !ok {
				return nil, merr.WrapErrParameterInvalidMsg("unmarshal highlight queries failed: item in array is not json object")
			}

			text, ok := m["text"]
			if !ok {
				return nil, merr.WrapErrParameterInvalidMsg("unmarshal highlight queries failed: must set `text` in query")
			}

			textStr, ok := text.(string)
			if !ok {
				return nil, merr.WrapErrParameterInvalidMsg("unmarshal highlight queries failed: `text` must be string")
			}

			t, ok := m["type"]
			if !ok {
				return nil, merr.WrapErrParameterInvalidMsg("unmarshal highlight queries failed: must set `type` in query")
			}

			typeStr, ok := t.(string)
			if !ok {
				return nil, merr.WrapErrParameterInvalidMsg("unmarshal highlight queries failed: `type` must be string")
			}

			typeEnum, ok := querypb.HighlightQueryType_value[typeStr]
			if !ok {
				return nil, merr.WrapErrParameterInvalidMsg("unmarshal highlight queries failed: invalid highlight query type: %s", typeStr)
			}

			f, ok := m["field"]
			if !ok {
				return nil, merr.WrapErrParameterInvalidMsg("unmarshal highlight queries failed: must set `field` in query")
			}
			fieldStr, ok := f.(string)
			if !ok {
				return nil, merr.WrapErrParameterInvalidMsg("unmarshal highlight queries failed: `field` must be string")
			}

			h.queries = append(h.queries, &highlightQuery{
				text:          textStr,
				highlightType: querypb.HighlightQueryType(typeEnum),
				fieldName:     fieldStr,
			})
		}
	}
	return h, nil
}

type lexicalHighlightOperator struct {
	tasks        []*highlightTask
	fieldSchemas []*schemapb.FieldSchema
	lbPolicy     shardclient.LBPolicy
	scheduler    *taskScheduler

	collectionName string
	collectionID   int64
	dbName         string
}

func newLexicalHighlightOperator(t *searchTask, tasks []*highlightTask) (operator, error) {
	return &lexicalHighlightOperator{
		tasks:          tasks,
		lbPolicy:       t.lb,
		scheduler:      t.node.(*Proxy).sched,
		fieldSchemas:   typeutil.GetAllFieldSchemas(t.schema.CollectionSchema),
		collectionName: t.request.CollectionName,
		collectionID:   t.CollectionID,
		dbName:         t.request.DbName,
	}, nil
}

func (op *lexicalHighlightOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
	result := inputs[0].(*milvuspb.SearchResults)
	datas := result.Results.GetFieldsData()
	req := &querypb.GetHighlightRequest{
		Topks: result.GetResults().GetTopks(),
		Tasks: lo.Map(op.tasks, func(task *highlightTask, _ int) *querypb.HighlightTask { return task.HighlightTask }),
	}

	for _, task := range req.GetTasks() {
		textFieldDatas, ok := lo.Find(datas, func(data *schemapb.FieldData) bool { return data.FieldId == task.GetFieldId() })
		if !ok {
			return nil, errors.Errorf("get highlight failed, text field not in output field %s: %d", task.GetFieldName(), task.GetFieldId())
		}
		texts := textFieldDatas.GetScalars().GetStringData().GetData()
		task.Texts = append(task.Texts, texts...)
		task.CorpusTextNum = int64(len(texts))
		field, ok := lo.Find(op.fieldSchemas, func(schema *schemapb.FieldSchema) bool {
			return schema.GetFieldID() == task.GetFieldId()
		})
		if !ok {
			return nil, errors.Errorf("get highlight failed, field not found in schema %s: %d", task.GetFieldName(), task.GetFieldId())
		}

		// if the field uses a multi-analyzer, fetch the analyzer field data
		helper := typeutil.CreateFieldSchemaHelper(field)
		if v, ok := helper.GetMultiAnalyzerParams(); ok {
			params := map[string]any{}
			err := json.Unmarshal([]byte(v), &params)
			if err != nil {
				return nil, errors.Errorf("get highlight failed, get invalid multi analyzer params: %v", err)
			}
			analyzerField, ok := params["by_field"]
			if !ok {
				return nil, errors.Errorf("get highlight failed, get invalid multi analyzer params, no by_field")
			}

			analyzerFieldDatas, ok := lo.Find(datas, func(data *schemapb.FieldData) bool { return data.FieldName == analyzerField.(string) })
			if !ok {
				return nil, errors.Errorf("get highlight failed, analyzer field not in output field")
			}
			task.AnalyzerNames = append(task.AnalyzerNames, analyzerFieldDatas.GetScalars().GetStringData().GetData()...)
		}
	}

	task := &HighlightTask{
		ctx:                 ctx,
		lb:                  op.lbPolicy,
		Condition:           NewTaskCondition(ctx),
		GetHighlightRequest: req,
		collectionName:      op.collectionName,
		collectionID:        op.collectionID,
		dbName:              op.dbName,
	}
	if err := op.scheduler.dqQueue.Enqueue(task); err != nil {
		return nil, err
	}

	if err := task.WaitToFinish(); err != nil {
		return nil, err
	}

	rowNum := len(result.Results.GetScores())
	HighlightResults := []*commonpb.HighlightResult{}
	if rowNum != 0 {
		rowDatas := lo.Map(task.result.Results, func(result *querypb.HighlightResult, i int) *commonpb.HighlightData {
			return buildStringFragments(op.tasks[i/rowNum], i%rowNum, result.GetFragments())
		})

		for i, task := range req.GetTasks() {
			HighlightResults = append(HighlightResults, &commonpb.HighlightResult{
				FieldName: task.GetFieldName(),
				Datas:     rowDatas[i*rowNum : (i+1)*rowNum],
			})
		}
	}

	result.Results.HighlightResults = HighlightResults
	return []any{result}, nil
}

func buildStringFragments(task *highlightTask, idx int, frags []*querypb.HighlightFragment) *commonpb.HighlightData {
	startOffset := int(task.GetSearchTextNum()) + len(task.Queries)
	text := []rune(task.Texts[startOffset+idx])
	preTagsNum := len(task.preTags)
	postTagsNum := len(task.postTags)
	result := &commonpb.HighlightData{Fragments: make([]string, 0)}
	for _, frag := range frags {
		var fragString strings.Builder
		cursor := int(frag.GetStartOffset())
		for i := 0; i < len(frag.GetOffsets())/2; i++ {
			startOffset := int(frag.Offsets[i<<1])
			endOffset := int(frag.Offsets[(i<<1)+1])
			if cursor < startOffset {
				fragString.WriteString(string(text[cursor:startOffset]))
			}
			fragString.WriteString(string(task.preTags[i%preTagsNum]))
			fragString.WriteString(string(text[startOffset:endOffset]))
			fragString.WriteString(string(task.postTags[i%postTagsNum]))
			cursor = endOffset
		}
		if cursor < int(frag.GetEndOffset()) {
			fragString.WriteString(string(text[cursor:frag.GetEndOffset()]))
		}
		result.Fragments = append(result.Fragments, fragString.String())
	}
	return result
}
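Editor's sketch (values are made up): buildStringFragments above indexes task.Texts as the search texts first, then one text per user-defined query, then the corpus texts appended by the operator. The standalone snippet below shows that index arithmetic.

package main

import "fmt"

func main() {
	searchTextNum := 1 // nq
	queryNum := 1      // len(task.Queries)
	texts := []string{
		"what is a vector database", // search text: [0, searchTextNum)
		"vector database",           // query text: [searchTextNum, searchTextNum+queryNum)
		"Milvus is a vector database built for scale", // corpus text
		"Relational databases store rows",             // corpus text
	}
	corpusStart := searchTextNum + queryNum // 2, as in buildStringFragments
	fmt.Println(texts[corpusStart+0])       // first corpus text
	fmt.Println(texts[corpusStart+1])       // second corpus text
}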
@@ -21,7 +21,6 @@ import (
 	"context"
 	"fmt"
 
-	"github.com/cockroachdb/errors"
 	"github.com/samber/lo"
 	"go.opentelemetry.io/otel"
 	"go.opentelemetry.io/otel/trace"

@@ -30,16 +29,13 @@ import (
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
-	"github.com/milvus-io/milvus/internal/json"
 	"github.com/milvus-io/milvus/internal/parser/planparserv2"
-	"github.com/milvus-io/milvus/internal/proxy/shardclient"
 	"github.com/milvus-io/milvus/internal/types"
 	"github.com/milvus-io/milvus/internal/util/function/rerank"
 	"github.com/milvus-io/milvus/internal/util/segcore"
 	"github.com/milvus-io/milvus/pkg/v2/log"
 	"github.com/milvus-io/milvus/pkg/v2/proto/internalpb"
 	"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
-	"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
 	"github.com/milvus-io/milvus/pkg/v2/util/commonpbutil"
 	"github.com/milvus-io/milvus/pkg/v2/util/merr"
 	"github.com/milvus-io/milvus/pkg/v2/util/paramtable"
@@ -598,192 +594,8 @@ func (op *filterFieldOperator) run(ctx context.Context, span trace.Span, inputs
 	return []any{result}, nil
 }
 
-const (
-	PreTagsKey             = "pre_tags"
-	PostTagsKey            = "post_tags"
-	HighlightSearchTextKey = "highlight_search_data"
-	FragmentOffsetKey      = "fragment_offset"
-	FragmentSizeKey        = "fragment_size"
-	FragmentNumKey         = "num_of_fragments"
-	DefaultFragmentSize    = 100
-	DefaultFragmentNum     = 1
-	DefaultPreTag          = "<em>"
-	DefaultPostTag         = "</em>"
-)
-
-type highlightTask struct {
-	*querypb.HighlightTask
-	preTags  [][]byte
-	postTags [][]byte
-}
-
-type highlightOperator struct {
-	tasks        []*highlightTask
-	fieldSchemas []*schemapb.FieldSchema
-	lbPolicy     shardclient.LBPolicy
-	scheduler    *taskScheduler
-
-	collectionName string
-	collectionID   int64
-	dbName         string
-
-	preTag  []byte
-	postTag []byte
-}
-
 func newHighlightOperator(t *searchTask, _ map[string]any) (operator, error) {
-	return &highlightOperator{
-		tasks:          t.highlightTasks,
-		lbPolicy:       t.lb,
-		scheduler:      t.node.(*Proxy).sched,
-		fieldSchemas:   typeutil.GetAllFieldSchemas(t.schema.CollectionSchema),
-		collectionName: t.request.CollectionName,
-		collectionID:   t.CollectionID,
-		dbName:         t.request.DbName,
-	}, nil
-}
-
-func sliceByRune(s string, start, end int) string {
-	if start >= end {
-		return ""
-	}
-
-	i, from, to := 0, 0, len(s)
-	for idx := range s {
-		if i == start {
-			from = idx
-		}
-		if i == end {
-			to = idx
-			break
-		}
-		i++
-	}
-
-	return s[from:to]
-}
-
-// get slice texts according to fragment options
-func getHighlightTexts(task *querypb.HighlightTask, datas []string) []string {
-	if task.GetOptions().GetNumOfFragments() == 0 {
-		return datas
-	}
-
-	results := make([]string, len(datas))
-	offset := int(task.GetOptions().GetFragmentOffset())
-	size := offset + int(task.GetOptions().GetFragmentSize()*task.GetOptions().GetNumOfFragments())
-	for i, text := range datas {
-		results[i] = sliceByRune(text, min(offset, len(text)), min(size, len(text)))
-	}
-	return results
-}
-
-func (op *highlightOperator) run(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {
-	result := inputs[0].(*milvuspb.SearchResults)
-	datas := result.Results.GetFieldsData()
-	req := &querypb.GetHighlightRequest{
-		Topks: result.GetResults().GetTopks(),
-		Tasks: lo.Map(op.tasks, func(task *highlightTask, _ int) *querypb.HighlightTask { return task.HighlightTask }),
-	}
-
-	for _, task := range req.GetTasks() {
-		textFieldDatas, ok := lo.Find(datas, func(data *schemapb.FieldData) bool { return data.FieldId == task.GetFieldId() })
-		if !ok {
-			return nil, errors.Errorf("get highlight failed, text field not in output field %s: %d", task.GetFieldName(), task.GetFieldId())
-		}
-		texts := getHighlightTexts(task, textFieldDatas.GetScalars().GetStringData().GetData())
-		task.Texts = append(task.Texts, texts...)
-		task.CorpusTextNum = int64(len(texts))
-		field, ok := lo.Find(op.fieldSchemas, func(schema *schemapb.FieldSchema) bool {
-			return schema.GetFieldID() == task.GetFieldId()
-		})
-		if !ok {
-			return nil, errors.Errorf("get highlight failed, field not found in schema %s: %d", task.GetFieldName(), task.GetFieldId())
-		}
-
-		// if use multi analyzer
-		// get analyzer field data
-		helper := typeutil.CreateFieldSchemaHelper(field)
-		if v, ok := helper.GetMultiAnalyzerParams(); ok {
-			params := map[string]any{}
-			err := json.Unmarshal([]byte(v), &params)
-			if err != nil {
-				return nil, errors.Errorf("get highlight failed, get invalid multi analyzer params: %v", err)
-			}
-			analyzerField, ok := params["by_field"]
-			if !ok {
-				return nil, errors.Errorf("get highlight failed, get invalid multi analyzer params, no by_field")
-			}
-
-			analyzerFieldDatas, ok := lo.Find(datas, func(data *schemapb.FieldData) bool { return data.FieldName == analyzerField.(string) })
-			if !ok {
-				return nil, errors.Errorf("get highlight failed, analyzer field not in output field")
-			}
-			task.AnalyzerNames = append(task.AnalyzerNames, analyzerFieldDatas.GetScalars().GetStringData().GetData()...)
-		}
-	}
-
-	task := &HighlightTask{
-		ctx:                 ctx,
-		lb:                  op.lbPolicy,
-		Condition:           NewTaskCondition(ctx),
-		GetHighlightRequest: req,
-		collectionName:      op.collectionName,
-		collectionID:        op.collectionID,
-		dbName:              op.dbName,
-	}
-	if err := op.scheduler.dqQueue.Enqueue(task); err != nil {
-		return nil, err
-	}
-
-	if err := task.WaitToFinish(); err != nil {
-		return nil, err
-	}
-
-	rowNum := len(result.Results.GetScores())
-	HighlightResults := []*commonpb.HighlightResult{}
-	if rowNum != 0 {
-		rowDatas := lo.Map(task.result.Results, func(result *querypb.HighlightResult, i int) *commonpb.HighlightData {
-			return buildStringFragments(op.tasks[i/rowNum], i%rowNum, result.GetFragments())
-		})
-
-		for i, task := range req.GetTasks() {
-			HighlightResults = append(HighlightResults, &commonpb.HighlightResult{
-				FieldName: task.GetFieldName(),
-				Datas:     rowDatas[i*rowNum : (i+1)*rowNum],
-			})
-		}
-	}
-
-	result.Results.HighlightResults = HighlightResults
-	return []any{result}, nil
-}
-
-func buildStringFragments(task *highlightTask, idx int, frags []*querypb.HighlightFragment) *commonpb.HighlightData {
-	bytes := []byte(task.Texts[int(task.GetSearchTextNum())+idx])
-	preTagsNum := len(task.preTags)
-	postTagsNum := len(task.postTags)
-	result := &commonpb.HighlightData{Fragments: make([]string, 0)}
-	for _, frag := range frags {
-		fragBytes := []byte{}
-		cursor := int(frag.GetStartOffset())
-		for i := 0; i < len(frag.GetOffsets())/2; i++ {
-			startOffset := int(frag.Offsets[i<<1])
-			endOffset := int(frag.Offsets[(i<<1)+1])
-			if cursor < startOffset {
-				fragBytes = append(fragBytes, bytes[cursor:startOffset]...)
-			}
-			fragBytes = append(fragBytes, task.preTags[i%preTagsNum]...)
-			fragBytes = append(fragBytes, bytes[startOffset:endOffset]...)
-			fragBytes = append(fragBytes, task.postTags[i%postTagsNum]...)
-			cursor = endOffset
-		}
-		if cursor < int(frag.GetEndOffset()) {
-			fragBytes = append(fragBytes, bytes[cursor:frag.GetEndOffset()]...)
-		}
-		result.Fragments = append(result.Fragments, string(fragBytes))
-	}
-	return result
-}
+	return t.highlighter.AsSearchPipelineOperator(t)
 }
 
 func mergeIDsFunc(ctx context.Context, span trace.Span, inputs ...any) ([]any, error) {

@@ -1318,7 +1130,7 @@ func newSearchPipeline(t *searchTask) (*pipeline, error) {
 		return nil, err
 	}
 
-	if len(t.highlightTasks) > 0 {
+	if t.highlighter != nil {
 		err := p.AddNodes(t, highlightNode, filterFieldNode)
 		if err != nil {
 			return nil, err
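For orientation, a minimal standalone sketch of the pipeline contract these hunks rely on: each node's operator consumes the previous node's outputs. The real operator interface lives elsewhere in the proxy package and also threads a trace span through run; its exact shape is assumed here.

package main

import (
	"context"
	"fmt"
)

// operator is an assumed, simplified stand-in for the proxy's node contract.
type operator interface {
	run(ctx context.Context, inputs ...any) ([]any, error)
}

type echoOp struct{ name string }

func (e echoOp) run(_ context.Context, inputs ...any) ([]any, error) {
	fmt.Println(e.name, "saw", len(inputs), "inputs")
	return inputs, nil
}

// runChain feeds each operator's outputs into the next one.
func runChain(ctx context.Context, ops []operator, inputs []any) ([]any, error) {
	var err error
	for _, op := range ops {
		if inputs, err = op.run(ctx, inputs...); err != nil {
			return nil, err
		}
	}
	return inputs, nil
}

func main() {
	out, _ := runChain(context.Background(),
		[]operator{echoOp{"highlight"}, echoOp{"filterField"}},
		[]any{"search results"})
	fmt.Println(out)
}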
@@ -299,8 +299,8 @@ func (s *SearchPipelineSuite) TestHighlightOp() {
 		DbName: "default",
 	}
 
-	highlightTasks := []*highlightTask{
-		{
+	highlightTasks := map[int64]*highlightTask{
+		100: {
 			HighlightTask: &querypb.HighlightTask{
 				Texts:     []string{"target text"},
 				FieldName: testVarCharField,

@@ -314,7 +314,9 @@ func (s *SearchPipelineSuite) TestHighlightOp() {
 	mockLb := shardclient.NewMockLBPolicy(s.T())
 	searchTask := &searchTask{
 		node: proxy,
-		highlightTasks: highlightTasks,
+		highlighter: &LexicalHighlighter{
+			tasks: highlightTasks,
+		},
 		lb:      mockLb,
 		schema:  newSchemaInfo(schema),
 		request: req,
@@ -3204,7 +3204,7 @@ func (t *HighlightTask) OnEnqueue() error {
 	if t.Base == nil {
 		t.Base = commonpbutil.NewMsgBase()
 	}
-	t.Base.MsgType = commonpb.MsgType_RunAnalyzer
+	t.Base.MsgType = commonpb.MsgType_Undefined
 	t.Base.SourceID = paramtable.GetNodeID()
 	return nil
 }
@@ -17,7 +17,6 @@ import (
 	"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/milvuspb"
 	"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
-	"github.com/milvus-io/milvus/internal/json"
 	"github.com/milvus-io/milvus/internal/parser/planparserv2"
 	"github.com/milvus-io/milvus/internal/proxy/accesslog"
 	"github.com/milvus-io/milvus/internal/proxy/shardclient"

@@ -82,7 +81,7 @@ type searchTask struct {
 	translatedOutputFields []string
 	userOutputFields       []string
 	userDynamicFields      []string
-	highlightTasks         []*highlightTask
+	highlighter            Highlighter
 	resultBuf              *typeutil.ConcurrentSet[*internalpb.SearchResults]
 
 	partitionIDsSet *typeutil.ConcurrentSet[UniqueID]
@@ -478,9 +477,6 @@ func (t *searchTask) initAdvancedSearchRequest(ctx context.Context) error {
 		}
 
 		internalSubReq.FieldId = queryInfo.GetQueryFieldId()
-		if err := t.addHighlightTask(t.request.GetHighlighter(), internalSubReq.GetMetricType(), internalSubReq.FieldId, subReq.GetPlaceholderGroup(), internalSubReq.GetAnalyzerName()); err != nil {
-			return err
-		}
 
 		queryFieldIDs = append(queryFieldIDs, internalSubReq.FieldId)
 		// set PartitionIDs for sub search
@@ -586,17 +582,12 @@ func (t *searchTask) getBM25SearchTexts(placeholder []byte) ([]string, error) {
 }
 
 func (t *searchTask) createLexicalHighlighter(highlighter *commonpb.Highlighter, metricType string, annsField int64, placeholder []byte, analyzerName string) error {
-	task := &highlightTask{
-		HighlightTask: &querypb.HighlightTask{
-			Options: &querypb.HighlightOptions{
-				FragmentSize:   DefaultFragmentSize,
-				NumOfFragments: DefaultFragmentNum,
-			},
-		},
-	}
-
-	params := funcutil.KeyValuePair2Map(highlighter.GetParams())
-
-	if metricType != metric.BM25 {
-		return merr.WrapErrParameterInvalidMsg(`Search highlight only support with metric type "BM25" but was: %s`, t.SearchRequest.GetMetricType())
-	}
+	h, err := NewLexicalHighlighter(highlighter)
+	if err != nil {
+		return err
+	}
+	t.highlighter = h
+	if h.highlightSearch {
+		if metricType != metric.BM25 {
+			return merr.WrapErrParameterInvalidMsg(`Search highlight only support with metric type "BM25" but was: %s`, t.SearchRequest.GetMetricType())
+		}

@@ -604,97 +595,15 @@ func (t *searchTask) createLexicalHighlighter(highlighter *commonpb.Highlighter,
 	if !ok {
 		return merr.WrapErrServiceInternal(`Search with highlight failed, input field of BM25 annsField not found`)
 	}
-	task.FieldId = function.InputFieldIds[0]
-	task.FieldName = function.InputFieldNames[0]
-
-	if value, ok := params[HighlightSearchTextKey]; ok {
-		enable, err := strconv.ParseBool(value)
-		if err != nil {
-			return merr.WrapErrParameterInvalidMsg("unmarshal highlight_search_data as bool failed: %v", err)
-		}
-
-		// now only support highlight with search
-		// so skip if highlight search not enable.
-		if !enable {
-			return nil
-		}
-	}
-
-	// set pre_tags and post_tags
-	if value, ok := params[PreTagsKey]; ok {
-		tags := []string{}
-		if err := json.Unmarshal([]byte(value), &tags); err != nil {
-			return merr.WrapErrParameterInvalidMsg("unmarshal pre_tags as string list failed: %v", err)
-		}
-		if len(tags) == 0 {
-			return merr.WrapErrParameterInvalidMsg("pre_tags cannot be empty list")
-		}
-
-		task.preTags = make([][]byte, len(tags))
-		for i, tag := range tags {
-			task.preTags[i] = []byte(tag)
-		}
-	} else {
-		task.preTags = [][]byte{[]byte(DefaultPreTag)}
-	}
-
-	if value, ok := params[PostTagsKey]; ok {
-		tags := []string{}
-		if err := json.Unmarshal([]byte(value), &tags); err != nil {
-			return merr.WrapErrParameterInvalidMsg("unmarshal post_tags as string list failed: %v", err)
-		}
-		if len(tags) == 0 {
-			return merr.WrapErrParameterInvalidMsg("post_tags cannot be empty list")
-		}
-		task.postTags = make([][]byte, len(tags))
-		for i, tag := range tags {
-			task.postTags[i] = []byte(tag)
-		}
-	} else {
-		task.postTags = [][]byte{[]byte(DefaultPostTag)}
-	}
-
-	// set fragment config
-	if value, ok := params[FragmentSizeKey]; ok {
-		fragmentSize, err := strconv.ParseInt(value, 10, 64)
-		if err != nil || fragmentSize <= 0 {
-			return merr.WrapErrParameterInvalidMsg("invalid fragment_size: %s", value)
-		}
-		task.Options.FragmentSize = fragmentSize
-	}
-
-	if value, ok := params[FragmentNumKey]; ok {
-		fragmentNum, err := strconv.ParseInt(value, 10, 64)
-		if err != nil || fragmentNum <= 0 {
-			return merr.WrapErrParameterInvalidMsg("invalid fragment_size: %s", value)
-		}
-		task.Options.NumOfFragments = fragmentNum
-	}
-
-	if value, ok := params[FragmentOffsetKey]; ok {
-		fragmentOffset, err := strconv.ParseInt(value, 10, 64)
-		if err != nil || fragmentOffset <= 0 {
-			return merr.WrapErrParameterInvalidMsg("invalid fragment_size: %s", value)
-		}
-		task.Options.NumOfFragments = fragmentOffset
-	}
-
-	// set bm25 search text as query texts
-	texts, err := t.getBM25SearchTexts(placeholder)
-	if err != nil {
-		return err
-	}
-
-	task.Texts = texts
-	task.SearchTextNum = int64(len(texts))
-	if analyzerName != "" {
-		task.AnalyzerNames = []string{}
-		for i := 0; i < len(texts); i++ {
-			task.AnalyzerNames = append(task.AnalyzerNames, analyzerName)
-		}
-	}
-
-	t.highlightTasks = append(t.highlightTasks, task)
-	return nil
-}
+		fieldId := function.InputFieldIds[0]
+		fieldName := function.InputFieldNames[0]
+		// set bm25 search text as highlight search texts
+		texts, err := t.getBM25SearchTexts(placeholder)
+		if err != nil {
+			return err
+		}
+		return h.addTaskWithSearchText(fieldId, fieldName, analyzerName, texts)
+	}
+	return nil
+}
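Editor's sketch: the `queries` validation that NewLexicalHighlighter performs (and which createLexicalHighlighter now reaches by calling it) boils down to the standalone check below. The helper name parseQueries and the type table are illustrative, not part of the commit.

package main

import (
	"encoding/json"
	"fmt"
)

// Illustrative stand-in for querypb.HighlightQueryType_value.
var highlightQueryTypes = map[string]int32{"TextMatch": 0}

func parseQueries(value string) error {
	queries := []any{}
	if err := json.Unmarshal([]byte(value), &queries); err != nil {
		return fmt.Errorf("unmarshal highlight queries as json array failed: %w", err)
	}
	for _, q := range queries {
		m, ok := q.(map[string]any)
		if !ok {
			return fmt.Errorf("item in array is not json object")
		}
		// each of text, type and field must be present and a string
		for _, key := range []string{"text", "type", "field"} {
			v, ok := m[key]
			if !ok {
				return fmt.Errorf("must set `%s` in query", key)
			}
			if _, ok := v.(string); !ok {
				return fmt.Errorf("`%s` must be string", key)
			}
		}
		if _, ok := highlightQueryTypes[m["type"].(string)]; !ok {
			return fmt.Errorf("invalid highlight query type: %s", m["type"])
		}
	}
	return nil
}

func main() {
	fmt.Println(parseQueries(`[{"text":"vector","type":"TextMatch","field":"document"}]`)) // <nil>
	fmt.Println(parseQueries(`[{"text":"vector","type":"Fuzzy","field":"document"}]`))     // error
}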
@@ -4939,9 +4939,12 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
 
 		err := task.addHighlightTask(highlighter, metric.BM25, 101, placeholderBytes, "")
 		assert.NoError(t, err)
-		assert.Equal(t, 1, len(task.highlightTasks))
-		assert.Equal(t, int64(100), task.highlightTasks[0].FieldId)
-		assert.Equal(t, "text_field", task.highlightTasks[0].FieldName)
+		h, ok := task.highlighter.(*LexicalHighlighter)
+		require.True(t, ok)
+		require.Equal(t, 1, len(h.tasks))
+		assert.Equal(t, int64(100), h.tasks[100].FieldId)
+		assert.Equal(t, "text_field", h.tasks[100].FieldName)
 	})
 
 	t.Run("Lexical highlight with custom tags", func(t *testing.T) {

@@ -4958,11 +4961,13 @@ func TestSearchTask_AddHighlightTask(t *testing.T) {
 
 		err := task.addHighlightTask(highlighter, metric.BM25, 101, placeholderBytes, "")
 		assert.NoError(t, err)
-		assert.Equal(t, 1, len(task.highlightTasks))
-		assert.Equal(t, 1, len(task.highlightTasks[0].preTags))
-		assert.Equal(t, []byte("<b>"), task.highlightTasks[0].preTags[0])
-		assert.Equal(t, 1, len(task.highlightTasks[0].postTags))
-		assert.Equal(t, []byte("</b>"), task.highlightTasks[0].postTags[0])
+		h, ok := task.highlighter.(*LexicalHighlighter)
+		require.True(t, ok)
+		assert.Equal(t, 1, len(h.preTags))
+		assert.Equal(t, []byte("<b>"), h.preTags[0])
+		assert.Equal(t, 1, len(h.postTags))
+		assert.Equal(t, []byte("</b>"), h.postTags[0])
 	})
 
 	t.Run("lexical highlight with wrong metric type", func(t *testing.T) {
@@ -5550,7 +5550,7 @@ func TestHighlightTask(t *testing.T) {
 		err := task.OnEnqueue()
 		assert.NoError(t, err)
 		assert.NotNil(t, task.Base)
-		assert.Equal(t, commonpb.MsgType_RunAnalyzer, task.Base.MsgType)
+		assert.Equal(t, commonpb.MsgType_Undefined, task.Base.MsgType)
 		assert.Equal(t, paramtable.GetNodeID(), task.Base.SourceID)
 	})
 
@@ -1006,8 +1006,8 @@ func (sd *shardDelegator) DropIndex(ctx context.Context, req *querypb.DropIndexR
 func (sd *shardDelegator) GetHighlight(ctx context.Context, req *querypb.GetHighlightRequest) ([]*querypb.HighlightResult, error) {
 	result := []*querypb.HighlightResult{}
 	for _, task := range req.GetTasks() {
-		if len(task.GetTexts()) != int(task.GetSearchTextNum()+task.GetCorpusTextNum()) {
-			return nil, errors.Errorf("package highlight texts error, num of texts not equal the expected num %d:%d", len(task.GetTexts()), task.GetSearchTextNum()+task.GetCorpusTextNum())
+		if len(task.GetTexts()) != int(task.GetSearchTextNum()+task.GetCorpusTextNum())+len(task.GetQueries()) {
+			return nil, errors.Errorf("package highlight texts error, num of texts not equal the expected num %d:%d", len(task.GetTexts()), int(task.GetSearchTextNum()+task.GetCorpusTextNum())+len(task.GetQueries()))
 		}
 		analyzer, ok := sd.analyzerRunners[task.GetFieldId()]
 		if !ok {

@@ -1032,24 +1032,52 @@ func (sd *shardDelegator) GetHighlight(ctx context.Context, req *querypb.GetHigh
 
 	// analyze result of search text
 	searchResults := results[0:task.SearchTextNum]
+	// analyze result of query text
+	queryResults := results[task.SearchTextNum : task.SearchTextNum+int64(len(task.Queries))]
 	// analyze result of corpus text
-	corpusResults := results[task.SearchTextNum:]
-	corpusIdx := 0
-	for i, tokens := range searchResults {
-		tokenSet := typeutil.NewSet[string]()
-		for _, token := range tokens {
-			tokenSet.Insert(token.GetToken())
-		}
+	corpusStartOffset := int(task.SearchTextNum) + len(task.Queries)
+	corpusResults := results[corpusStartOffset:]
+
+	// query for all corpus texts
+	// only support text match now
+	// build match set for all analyze result of query text
+	// TODO: support more query types
+	queryTokenSet := typeutil.NewSet[string]()
+	for _, tokens := range queryResults {
+		for _, token := range tokens {
+			queryTokenSet.Insert(token.GetToken())
+		}
+	}
+
+	corpusIdx := 0
+	for i := range len(topks) {
+		tokenSet := typeutil.NewSet[string]()
+		if len(searchResults) > i {
+			for _, token := range searchResults[i] {
+				tokenSet.Insert(token.GetToken())
+			}
+		}
 
 		for j := 0; j < int(topks[i]); j++ {
 			spans := SpanList{}
 			for _, token := range corpusResults[corpusIdx] {
-				if tokenSet.Contain(token.GetToken()) {
+				if tokenSet.Contain(token.GetToken()) || queryTokenSet.Contain(token.GetToken()) {
 					spans = append(spans, Span{token.GetStartOffset(), token.GetEndOffset()})
 				}
 			}
 			spans = mergeOffsets(spans)
-			frags := fetchFragmentsFromOffsets(task.Texts[int(task.SearchTextNum)+corpusIdx], spans, task.GetOptions().GetFragmentSize(), task.GetOptions().GetNumOfFragments())
+
+			// Convert byte offsets from the analyzer to rune (character) offsets
+			corpusText := task.Texts[corpusStartOffset+corpusIdx]
+			err := bytesOffsetToRuneOffset(corpusText, spans)
+			if err != nil {
+				return nil, err
+			}
+
+			frags := fetchFragmentsFromOffsets(corpusText, spans,
+				task.GetOptions().GetFragmentOffset(),
+				task.GetOptions().GetFragmentSize(),
+				task.GetOptions().GetNumOfFragments())
 			result = append(result, &querypb.HighlightResult{Fragments: frags})
 			corpusIdx++
 		}
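Editor's sketch of the matching rule GetHighlight now applies: a corpus token is highlighted when it matches a token from the analyzed BM25 search text or from any analyzed user-defined query text. Plain maps stand in for typeutil.Set here; all values are made up.

package main

import "fmt"

func main() {
	searchTokens := map[string]struct{}{"vector": {}}  // from analyzed search text
	queryTokens := map[string]struct{}{"database": {}} // from analyzed query text
	corpusTokens := []string{"milvus", "is", "a", "vector", "database"}
	for _, tok := range corpusTokens {
		_, inSearch := searchTokens[tok]
		_, inQuery := queryTokens[tok]
		if inSearch || inQuery {
			fmt.Println("highlight:", tok)
		}
	}
}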
@@ -5,6 +5,7 @@ import (
 	"sort"
 	"unicode/utf8"
 
+	"github.com/cockroachdb/errors"
 	"go.uber.org/zap"
 	"google.golang.org/protobuf/proto"
 

@@ -14,6 +15,7 @@ import (
 	"github.com/milvus-io/milvus/pkg/v2/proto/planpb"
 	"github.com/milvus-io/milvus/pkg/v2/proto/querypb"
 	"github.com/milvus-io/milvus/pkg/v2/util/merr"
+	"github.com/milvus-io/milvus/pkg/v2/util/typeutil"
 )
 
 func BuildSparseFieldData(field *schemapb.FieldSchema, sparseArray *schemapb.SparseFloatArray) *schemapb.FieldData {

@@ -104,53 +106,78 @@ func mergeOffsets(input SpanList) SpanList {
 	return offsets
 }
 
+func bytesOffsetToRuneOffset(text string, spans SpanList) error {
+	byteOffsetSet := typeutil.NewSet[int64]()
+	for _, span := range spans {
+		byteOffsetSet.Insert(span[0])
+		byteOffsetSet.Insert(span[1])
+	}
+	offsetMap := map[int64]int64{0: 0, int64(len(text)): int64(utf8.RuneCountInString(text))}
+
+	cnt := int64(0)
+	for i := range text {
+		if byteOffsetSet.Contain(int64(i)) {
+			offsetMap[int64(i)] = cnt
+		}
+		cnt++
+	}
+
+	// convert spans from byte offsets to rune offsets
+	for i, span := range spans {
+		startOffset, ok := offsetMap[span[0]]
+		if !ok {
+			return errors.Errorf("start offset: %d not found (text: %d bytes)", span[0], len(text))
+		}
+		endOffset, ok := offsetMap[span[1]]
+		if !ok {
+			return errors.Errorf("end offset: %d not found (text: %d bytes)", span[1], len(text))
+		}
+		spans[i][0] = startOffset
+		spans[i][1] = endOffset
+	}
+	return nil
+}
+
-func fetchFragmentsFromOffsets(text string, span SpanList, fragmentSize int64, numOfFragments int64) []*querypb.HighlightFragment {
+func fetchFragmentsFromOffsets(text string, spans SpanList, fragmentOffset int64, fragmentSize int64, numOfFragments int64) []*querypb.HighlightFragment {
 	result := make([]*querypb.HighlightFragment, 0)
-	endPosition := int(fragmentSize)
-	nowOffset := 0
-	frag := &querypb.HighlightFragment{
-		StartOffset: 0,
-	}
-
-	next := func() {
-		endPosition += int(fragmentSize)
-		frag.EndOffset = int64(nowOffset)
-		result = append(result, frag)
-		frag = &querypb.HighlightFragment{
-			StartOffset: int64(nowOffset),
-		}
-	}
-
-	cursor := 0
-	spanNum := len(span)
-	for i, r := range text {
-		nowOffset += utf8.RuneLen(r)
-
-		// append if span was included in current fragment
-		for ; cursor < spanNum && span[cursor][1] <= int64(nowOffset); cursor++ {
-			if span[cursor][0] >= frag.StartOffset {
-				frag.Offsets = append(frag.Offsets, span[cursor][0], span[cursor][1])
-			} else {
-				// if some span cross fragment start, append the part in current fragment
-				frag.Offsets = append(frag.Offsets, frag.StartOffset, span[cursor][1])
-			}
-		}
-
-		if i >= endPosition {
-			// if some span cross fragment end, append the part in current fragment
-			if cursor < spanNum && span[cursor][0] < int64(nowOffset) {
-				frag.Offsets = append(frag.Offsets, span[cursor][0], int64(nowOffset))
-			}
-			next()
-			// skip all if no span remain or get enough num of fragments
-			if cursor >= spanNum || int64(len(result)) >= numOfFragments {
-				break
-			}
-		}
-	}
-
-	if nowOffset > int(frag.StartOffset) {
-		next()
-	}
+	textRuneLen := int64(utf8.RuneCountInString(text))
+
+	var frag *querypb.HighlightFragment = nil
+	next := func(span *Span) bool {
+		startOffset := max(0, span[0]-fragmentOffset)
+		endOffset := min(max(span[1], startOffset+fragmentSize), textRuneLen)
+		if frag != nil {
+			result = append(result, frag)
+		}
+		if len(result) >= int(numOfFragments) {
+			frag = nil
+			return false
+		}
+		frag = &querypb.HighlightFragment{
+			StartOffset: startOffset,
+			EndOffset:   endOffset,
+			Offsets:     []int64{span[0], span[1]},
+		}
+		return true
+	}
+
+	for i, span := range spans {
+		if frag == nil || span[0] > frag.EndOffset {
+			if !next(&span) {
+				break
+			}
+		} else {
+			// append rune offset to fragment
+			frag.Offsets = append(frag.Offsets, spans[i][0], spans[i][1])
+			// extend fragment end offset if this span goes beyond current boundary
+			if span[1] > frag.EndOffset {
+				frag.EndOffset = span[1]
+			}
+		}
+	}
+
+	if frag != nil {
+		result = append(result, frag)
+	}
 	return result
 }
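A worked numeric sketch of the new fragment window arithmetic above: a fragment opens fragmentOffset runes before the first uncovered span and extends to at least fragmentSize runes, clamped to the text length. Values are made up; min and max are the Go 1.21 builtins, as used in the new code itself.

package main

import "fmt"

func main() {
	fragmentOffset, fragmentSize, textRuneLen := int64(10), int64(100), int64(500)
	span := [2]int64{120, 128} // rune offsets of one highlighted term
	start := max(int64(0), span[0]-fragmentOffset)            // 110
	end := min(max(span[1], start+fragmentSize), textRuneLen) // 210
	fmt.Println(start, end)
}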
internal/querynodev2/delegator/util_test.go (new file, 38 lines)
@@ -0,0 +1,38 @@
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package delegator

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestBytesOffsetToRuneOffset(t *testing.T) {
	// test with chinese
	text := "你好世界" // 12 bytes, 4 runes
	spans := SpanList{{0, 6}, {6, 12}}
	err := bytesOffsetToRuneOffset(text, spans)
	assert.NoError(t, err)
	assert.Equal(t, SpanList{{0, 2}, {2, 4}}, spans)

	// test with emoji
	text = "Hello👋World" // 15 bytes, 11 runes
	spans = SpanList{{0, 5}, {5, 9}, {9, 14}}
	err = bytesOffsetToRuneOffset(text, spans)
	assert.NoError(t, err)
	assert.Equal(t, SpanList{{0, 5}, {5, 6}, {6, 11}}, spans)
}
@@ -1020,19 +1020,32 @@ message HighlightOptions{
 	int64 num_of_fragments = 3;
 }
 
+enum HighlightQueryType{
+	TextMatch = 0;
+}
+
+message HighlightQuery{
+	HighlightQueryType type = 1;
+}
+
 // HighlightTask fetch highlight for all queries at one field
-// len(texts) == search_text_num + corpus_text_num
+// search_text_num == len(topks) == nq
+// corpus_text_num == sum(topks) == len(search_results)
 message HighlightTask{
 	string field_name = 1;
 	int64 field_id = 2;
+	// len(texts) = search_text_num + corpus_text_num + len(queries);
+	// texts = search_text...query_text...corpus_text
 	repeated string texts = 3;
 	repeated string analyzer_names = 4; // used if field with multi-analyzer
 
 	int64 search_text_num = 5;
 	int64 corpus_text_num = 6;
 	HighlightOptions options = 7;
+	repeated HighlightQuery queries = 8;
 }
 
+// Get Lexical highlight from delegator
 message GetHighlightRequest{
 	common.MsgBase base = 1;
 	string channel = 2;

@@ -1045,6 +1058,7 @@ message GetHighlightRequest{
 message HighlightFragment{
 	int64 start_offset = 1;
 	int64 end_offset = 2;
+	// char offset of the highlight terms in the fragment
 	repeated int64 offsets = 3;
 }
 
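Editor's sketch tying the proto fields together: the snippet below mirrors how buildStringFragments in highlighter.go renders a HighlightFragment, treating offsets as rune-offset pairs into the corpus text and wrapping each pair in pre/post tags. All values are illustrative.

package main

import (
	"fmt"
	"strings"
)

func main() {
	text := []rune("milvus is a vector database")
	fragStart, fragEnd := int64(0), int64(27) // HighlightFragment start/end
	offsets := []int64{12, 18}                // rune offsets of "vector"

	var b strings.Builder
	cursor := fragStart
	for i := 0; i < len(offsets)/2; i++ {
		s, e := offsets[2*i], offsets[2*i+1]
		b.WriteString(string(text[cursor:s])) // text before the term
		b.WriteString("<em>")
		b.WriteString(string(text[s:e])) // the highlighted term
		b.WriteString("</em>")
		cursor = e
	}
	b.WriteString(string(text[cursor:fragEnd])) // trailing text
	fmt.Println(b.String())                     // milvus is a <em>vector</em> database
}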
File diff suppressed because it is too large