mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 01:28:27 +08:00
enhance: [2.5] Add param item for hybrid search requery policy (#44467)
Cherry-pick from master pr: #44466 related to #39757 --------- Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
This commit is contained in:
parent
cdcad7b1c7
commit
d251e102b6
@ -515,6 +515,9 @@ func rankSearchResultDataByGroup(ctx context.Context,
|
|||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// init FieldsData
|
||||||
|
ret.Results.FieldsData = typeutil.PrepareResultFieldData(searchResults[0].GetResults().GetFieldsData(), limit)
|
||||||
|
|
||||||
totalCount := limit * groupSize
|
totalCount := limit * groupSize
|
||||||
if err := setupIdListForSearchResult(ret, pkType, totalCount); err != nil {
|
if err := setupIdListForSearchResult(ret, pkType, totalCount); err != nil {
|
||||||
return ret, err
|
return ret, err
|
||||||
@ -526,11 +529,18 @@ func rankSearchResultDataByGroup(ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
accumulatedScores := make([]map[interface{}]*accumulateIDGroupVal, nq)
|
accumulatedScores := make([]map[interface{}]*accumulateIDGroupVal, nq)
|
||||||
|
type dataLoc struct {
|
||||||
|
resultIdx int
|
||||||
|
offset int
|
||||||
|
}
|
||||||
|
pk2DataOffset := make([]map[any]dataLoc, nq)
|
||||||
for i := int64(0); i < nq; i++ {
|
for i := int64(0); i < nq; i++ {
|
||||||
accumulatedScores[i] = make(map[interface{}]*accumulateIDGroupVal)
|
accumulatedScores[i] = make(map[interface{}]*accumulateIDGroupVal)
|
||||||
|
pk2DataOffset[i] = make(map[any]dataLoc)
|
||||||
}
|
}
|
||||||
|
|
||||||
groupByDataType := searchResults[0].GetResults().GetGroupByFieldValue().GetType()
|
groupByDataType := searchResults[0].GetResults().GetGroupByFieldValue().GetType()
|
||||||
for _, result := range searchResults {
|
for ri, result := range searchResults {
|
||||||
scores := result.GetResults().GetScores()
|
scores := result.GetResults().GetScores()
|
||||||
start := 0
|
start := 0
|
||||||
// milvus has limits for the value range of nq and limit
|
// milvus has limits for the value range of nq and limit
|
||||||
@ -540,6 +550,7 @@ func rankSearchResultDataByGroup(ctx context.Context,
|
|||||||
for j := start; j < start+realTopK; j++ {
|
for j := start; j < start+realTopK; j++ {
|
||||||
id := typeutil.GetPK(result.GetResults().GetIds(), int64(j))
|
id := typeutil.GetPK(result.GetResults().GetIds(), int64(j))
|
||||||
groupByVal := typeutil.GetData(result.GetResults().GetGroupByFieldValue(), j)
|
groupByVal := typeutil.GetData(result.GetResults().GetGroupByFieldValue(), j)
|
||||||
|
pk2DataOffset[i][id] = dataLoc{resultIdx: ri, offset: j}
|
||||||
if accumulatedScores[i][id] != nil {
|
if accumulatedScores[i][id] != nil {
|
||||||
accumulatedScores[i][id].accumulatedScore += scores[j]
|
accumulatedScores[i][id].accumulatedScore += scores[j]
|
||||||
} else {
|
} else {
|
||||||
@ -623,14 +634,16 @@ func rankSearchResultDataByGroup(ctx context.Context,
|
|||||||
returnedRowNum := 0
|
returnedRowNum := 0
|
||||||
for index := int(offset); index < len(groupList); index++ {
|
for index := int(offset); index < len(groupList); index++ {
|
||||||
group := groupList[index]
|
group := groupList[index]
|
||||||
for i, score := range group.scoreList {
|
for idx, score := range group.scoreList {
|
||||||
// idList and scoreList must have same length
|
// idList and scoreList must have same length
|
||||||
typeutil.AppendPKs(ret.Results.Ids, group.idList[i])
|
typeutil.AppendPKs(ret.Results.Ids, group.idList[idx])
|
||||||
if roundDecimal != -1 {
|
if roundDecimal != -1 {
|
||||||
multiplier := math.Pow(10.0, float64(roundDecimal))
|
multiplier := math.Pow(10.0, float64(roundDecimal))
|
||||||
score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
|
score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
|
||||||
}
|
}
|
||||||
ret.Results.Scores = append(ret.Results.Scores, score)
|
ret.Results.Scores = append(ret.Results.Scores, score)
|
||||||
|
loc := pk2DataOffset[i][group.idList[idx]]
|
||||||
|
typeutil.AppendFieldData(ret.Results.FieldsData, searchResults[loc.resultIdx].GetResults().GetFieldsData(), int64(loc.offset))
|
||||||
typeutil.AppendGroupByValue(ret.Results, group.groupVal, groupByDataType)
|
typeutil.AppendGroupByValue(ret.Results, group.groupVal, groupByDataType)
|
||||||
}
|
}
|
||||||
returnedRowNum += len(group.idList)
|
returnedRowNum += len(group.idList)
|
||||||
@ -699,23 +712,34 @@ func rankSearchResultDataByPk(ctx context.Context,
|
|||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// init FieldsData
|
||||||
|
ret.Results.FieldsData = typeutil.PrepareResultFieldData(searchResults[0].GetResults().GetFieldsData(), limit)
|
||||||
|
|
||||||
if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
|
if err := setupIdListForSearchResult(ret, pkType, limit); err != nil {
|
||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// []map[id]score
|
// []map[id]score
|
||||||
accumulatedScores := make([]map[interface{}]float32, nq)
|
accumulatedScores := make([]map[interface{}]float32, nq)
|
||||||
|
type dataLoc struct {
|
||||||
|
resultIdx int
|
||||||
|
offset int64
|
||||||
|
}
|
||||||
|
pk2DataOffset := make([]map[any]dataLoc, nq)
|
||||||
for i := int64(0); i < nq; i++ {
|
for i := int64(0); i < nq; i++ {
|
||||||
accumulatedScores[i] = make(map[interface{}]float32)
|
accumulatedScores[i] = make(map[interface{}]float32)
|
||||||
|
pk2DataOffset[i] = make(map[any]dataLoc)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, result := range searchResults {
|
for ri, result := range searchResults {
|
||||||
scores := result.GetResults().GetScores()
|
scores := result.GetResults().GetScores()
|
||||||
start := int64(0)
|
start := int64(0)
|
||||||
for i := int64(0); i < nq; i++ {
|
for i := int64(0); i < nq; i++ {
|
||||||
realTopk := result.GetResults().Topks[i]
|
realTopk := result.GetResults().Topks[i]
|
||||||
for j := start; j < start+realTopk; j++ {
|
for j := start; j < start+realTopk; j++ {
|
||||||
id := typeutil.GetPK(result.GetResults().GetIds(), j)
|
id := typeutil.GetPK(result.GetResults().GetIds(), j)
|
||||||
|
|
||||||
|
pk2DataOffset[i][id] = dataLoc{resultIdx: ri, offset: j}
|
||||||
accumulatedScores[i][id] += scores[j]
|
accumulatedScores[i][id] += scores[j]
|
||||||
}
|
}
|
||||||
start += realTopk
|
start += realTopk
|
||||||
@ -758,6 +782,8 @@ func rankSearchResultDataByPk(ctx context.Context,
|
|||||||
score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
|
score = float32(math.Floor(float64(score)*multiplier+0.5) / multiplier)
|
||||||
}
|
}
|
||||||
ret.Results.Scores = append(ret.Results.Scores, score)
|
ret.Results.Scores = append(ret.Results.Scores, score)
|
||||||
|
loc := pk2DataOffset[i][keys[index]]
|
||||||
|
typeutil.AppendFieldData(ret.Results.FieldsData, searchResults[loc.resultIdx].GetResults().GetFieldsData(), loc.offset)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cockroachdb/errors"
|
"github.com/cockroachdb/errors"
|
||||||
@ -215,7 +216,16 @@ func (t *searchTask) PreExecute(ctx context.Context) error {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if t.SearchRequest.GetIsAdvanced() {
|
if t.SearchRequest.GetIsAdvanced() {
|
||||||
t.requery = len(t.translatedOutputFields) > 0
|
switch strings.ToLower(paramtable.Get().CommonCfg.HybridSearchRequeryPolicy.GetValue()) {
|
||||||
|
case "always":
|
||||||
|
t.requery = true
|
||||||
|
case "outputvector":
|
||||||
|
t.requery = len(vectorOutputFields) > 0
|
||||||
|
case "outputfields":
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
t.requery = len(t.request.GetOutputFields()) > 0
|
||||||
|
}
|
||||||
err = t.initAdvancedSearchRequest(ctx)
|
err = t.initAdvancedSearchRequest(ctx)
|
||||||
} else {
|
} else {
|
||||||
t.requery = len(vectorOutputFields) > 0
|
t.requery = len(vectorOutputFields) > 0
|
||||||
|
|||||||
@ -317,6 +317,8 @@ type commonConfig struct {
|
|||||||
|
|
||||||
EnableConfigParamTypeCheck ParamItem `refreshable:"true"`
|
EnableConfigParamTypeCheck ParamItem `refreshable:"true"`
|
||||||
ClusterID ParamItem `refreshable:"false"`
|
ClusterID ParamItem `refreshable:"false"`
|
||||||
|
|
||||||
|
HybridSearchRequeryPolicy ParamItem `refreshable:"true"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *commonConfig) init(base *BaseTable) {
|
func (p *commonConfig) init(base *BaseTable) {
|
||||||
@ -1155,6 +1157,7 @@ This helps Milvus-CDC synchronize incremental data`,
|
|||||||
Export: true,
|
Export: true,
|
||||||
}
|
}
|
||||||
p.EnableConfigParamTypeCheck.Init(base.mgr)
|
p.EnableConfigParamTypeCheck.Init(base.mgr)
|
||||||
|
|
||||||
p.ClusterID = ParamItem{
|
p.ClusterID = ParamItem{
|
||||||
Key: "common.clusterID",
|
Key: "common.clusterID",
|
||||||
Version: "2.6.3",
|
Version: "2.6.3",
|
||||||
@ -1174,6 +1177,15 @@ This helps Milvus-CDC synchronize incremental data`,
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
p.ClusterID.Init(base.mgr)
|
p.ClusterID.Init(base.mgr)
|
||||||
|
|
||||||
|
p.HybridSearchRequeryPolicy = ParamItem{
|
||||||
|
Key: "common.requery.hybridSearchPolicy",
|
||||||
|
Version: "2.5.18",
|
||||||
|
DefaultValue: "OutputVector",
|
||||||
|
Doc: `the policy to decide when to do requery in hybrid search, support "always", "outputvector" and "outputfields"`,
|
||||||
|
Export: false,
|
||||||
|
}
|
||||||
|
p.HybridSearchRequeryPolicy.Init(base.mgr)
|
||||||
}
|
}
|
||||||
|
|
||||||
type gpuConfig struct {
|
type gpuConfig struct {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user