mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-06 17:18:35 +08:00
fix: revert add range search params check in proxy (#32366)
no need to check params in empty segment. #30365 Signed-off-by: lixinguo <xinguo.li@zilliz.com> Co-authored-by: lixinguo <xinguo.li@zilliz.com>
This commit is contained in:
parent
6ef677f79e
commit
365e50b63e
@ -90,11 +90,6 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
||||
searchParamStr = ""
|
||||
}
|
||||
|
||||
err = checkRangeSearchParams(searchParamStr, metricType)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 5. parse group by field
|
||||
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
|
||||
if err != nil {
|
||||
|
||||
@ -2,7 +2,6 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
@ -27,7 +26,6 @@ import (
|
||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||
"github.com/milvus-io/milvus/pkg/util/tsoutil"
|
||||
@ -939,64 +937,6 @@ func selectHighestScoreIndex(subSearchResultData []*schemapb.SearchResultData, s
|
||||
return subSearchIdx, resultDataIdx
|
||||
}
|
||||
|
||||
type rangeSearchParams struct {
|
||||
radius float64
|
||||
rangeFilter float64
|
||||
}
|
||||
|
||||
func checkRangeSearchParams(str string, metricType string) error {
|
||||
if len(str) == 0 {
|
||||
// no search params, no need to check
|
||||
return nil
|
||||
}
|
||||
var data map[string]*json.RawMessage
|
||||
err := json.Unmarshal([]byte(str), &data)
|
||||
if err != nil {
|
||||
log.Info("json Unmarshal fail when checkRangeSearchParams")
|
||||
return err
|
||||
}
|
||||
radius, ok := data[radiusKey]
|
||||
// will not do range search, no need to check
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if radius == nil {
|
||||
return merr.WrapErrParameterInvalidMsg("pass invalid type for radius")
|
||||
}
|
||||
var params rangeSearchParams
|
||||
err = json.Unmarshal(*radius, ¶ms.radius)
|
||||
if err != nil {
|
||||
return merr.WrapErrParameterInvalidMsg("must pass numpy type for radius")
|
||||
}
|
||||
|
||||
rangeFilter, ok := data[rangeFilterKey]
|
||||
// not pass range_filter, no need to check
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if rangeFilter == nil {
|
||||
return merr.WrapErrParameterInvalidMsg("pass invalid type for range_filter")
|
||||
}
|
||||
err = json.Unmarshal(*rangeFilter, ¶ms.rangeFilter)
|
||||
if err != nil {
|
||||
return merr.WrapErrParameterInvalidMsg("must pass numpy type for range_filter")
|
||||
}
|
||||
|
||||
if metric.PositivelyRelated(metricType) {
|
||||
if params.radius >= params.rangeFilter {
|
||||
msg := fmt.Sprintf("metric type '%s', range_filter(%f) must be greater than radius(%f)", metricType, params.rangeFilter, params.radius)
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
} else {
|
||||
if params.radius <= params.rangeFilter {
|
||||
msg := fmt.Sprintf("metric type '%s', range_filter(%f) must be less than radius(%f)", metricType, params.rangeFilter, params.radius)
|
||||
return merr.WrapErrParameterInvalidMsg(msg)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *searchTask) TraceCtx() context.Context {
|
||||
return t.ctx
|
||||
}
|
||||
|
||||
@ -132,56 +132,6 @@ func getBaseSearchParams() []*commonpb.KeyValuePair {
|
||||
}
|
||||
}
|
||||
|
||||
func getBaseParamsForRangeSearchL2() []*commonpb.KeyValuePair {
|
||||
return []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: AnnsFieldKey,
|
||||
Value: testFloatVecField,
|
||||
},
|
||||
{
|
||||
Key: TopKKey,
|
||||
Value: "10",
|
||||
},
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: metric.L2,
|
||||
},
|
||||
{
|
||||
Key: RoundDecimalKey,
|
||||
Value: "-1",
|
||||
},
|
||||
{
|
||||
Key: IgnoreGrowingKey,
|
||||
Value: "false",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getBaseParamsForRangeSearchIP() []*commonpb.KeyValuePair {
|
||||
return []*commonpb.KeyValuePair{
|
||||
{
|
||||
Key: AnnsFieldKey,
|
||||
Value: testFloatVecField,
|
||||
},
|
||||
{
|
||||
Key: TopKKey,
|
||||
Value: "10",
|
||||
},
|
||||
{
|
||||
Key: common.MetricTypeKey,
|
||||
Value: metric.IP,
|
||||
},
|
||||
{
|
||||
Key: RoundDecimalKey,
|
||||
Value: "-1",
|
||||
},
|
||||
{
|
||||
Key: IgnoreGrowingKey,
|
||||
Value: "false",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func getValidSearchParams() []*commonpb.KeyValuePair {
|
||||
return []*commonpb.KeyValuePair{
|
||||
{
|
||||
@ -2113,115 +2063,6 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("check range search params", func(t *testing.T) {
|
||||
normalParam := getValidSearchParams()
|
||||
|
||||
invalidTypeRadius := getBaseParamsForRangeSearchL2()
|
||||
invalidTypeRadius = append(invalidTypeRadius, &commonpb.KeyValuePair{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10, "radius": null}`,
|
||||
})
|
||||
|
||||
invalidTypeFilter := getBaseParamsForRangeSearchL2()
|
||||
invalidTypeFilter = append(invalidTypeFilter, &commonpb.KeyValuePair{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10, "radius": 10, "range_filter": null}`,
|
||||
})
|
||||
|
||||
normalParamWithNoFilter := getBaseParamsForRangeSearchL2()
|
||||
normalParamWithNoFilter = append(normalParamWithNoFilter, &commonpb.KeyValuePair{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10, "radius": 10}`,
|
||||
})
|
||||
|
||||
normalParamForIP := getBaseParamsForRangeSearchIP()
|
||||
normalParamForIP = append(normalParamForIP, &commonpb.KeyValuePair{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10, "radius": 10, "range_filter": 20}`,
|
||||
})
|
||||
|
||||
normalParamForL2 := getBaseParamsForRangeSearchL2()
|
||||
normalParamForL2 = append(normalParamForL2, &commonpb.KeyValuePair{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10, "radius": 20, "range_filter": 10}`,
|
||||
})
|
||||
|
||||
abnormalParamForIP := getBaseParamsForRangeSearchIP()
|
||||
abnormalParamForIP = append(abnormalParamForIP, &commonpb.KeyValuePair{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10, "radius": 20, "range_filter": 10}`,
|
||||
})
|
||||
|
||||
abnormalParamForL2 := getBaseParamsForRangeSearchL2()
|
||||
abnormalParamForL2 = append(abnormalParamForL2, &commonpb.KeyValuePair{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10, "radius": 10, "range_filter": 20}`,
|
||||
})
|
||||
|
||||
wrongTypeRadius := getBaseParamsForRangeSearchIP()
|
||||
wrongTypeRadius = append(wrongTypeRadius, &commonpb.KeyValuePair{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10, "radius": "ab"}`,
|
||||
})
|
||||
|
||||
wrongTypeFilter := getBaseParamsForRangeSearchIP()
|
||||
wrongTypeFilter = append(wrongTypeFilter, &commonpb.KeyValuePair{
|
||||
Key: SearchParamsKey,
|
||||
Value: `{"nprobe": 10, "radius": 10, "range_filter": "20"}`,
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
description string
|
||||
validParams []*commonpb.KeyValuePair
|
||||
}{
|
||||
{"normalParam", normalParam},
|
||||
{"normalParamWithNoFilter", normalParamWithNoFilter},
|
||||
{"normalParamForIP", normalParamForIP},
|
||||
{"normalParamForL2", normalParamForL2},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
info, _, err := parseSearchInfo(test.validParams, nil, false)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, info)
|
||||
})
|
||||
}
|
||||
|
||||
tests = []struct {
|
||||
description string
|
||||
validParams []*commonpb.KeyValuePair
|
||||
}{
|
||||
{"abnormalParamForIP", abnormalParamForIP},
|
||||
{"abnormalParamForL2", abnormalParamForL2},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
info, _, err := parseSearchInfo(test.validParams, nil, false)
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
assert.Nil(t, info)
|
||||
})
|
||||
}
|
||||
|
||||
tests = []struct {
|
||||
description string
|
||||
validParams []*commonpb.KeyValuePair
|
||||
}{
|
||||
{"invalidTypeRadius", invalidTypeRadius},
|
||||
{"invalidTypeFilter", invalidTypeFilter},
|
||||
{"wrongTypeRadius", wrongTypeRadius},
|
||||
{"wrongTypeFilter", wrongTypeFilter},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
info, _, err := parseSearchInfo(test.validParams, nil, false)
|
||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
||||
assert.Nil(t, info)
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("check iterator and groupBy", func(t *testing.T) {
|
||||
normalParam := getValidSearchParams()
|
||||
normalParam = append(normalParam, &commonpb.KeyValuePair{
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user