mirror of
https://gitee.com/milvus-io/milvus.git
synced 2025-12-07 09:38:39 +08:00
fix: revert add range search params check in proxy (#32366)
no need to check params in empty segment. #30365 Signed-off-by: lixinguo <xinguo.li@zilliz.com> Co-authored-by: lixinguo <xinguo.li@zilliz.com>
This commit is contained in:
parent
6ef677f79e
commit
365e50b63e
@ -90,11 +90,6 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
|
|||||||
searchParamStr = ""
|
searchParamStr = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
err = checkRangeSearchParams(searchParamStr, metricType)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 5. parse group by field
|
// 5. parse group by field
|
||||||
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
|
groupByFieldName, err := funcutil.GetAttrByKeyFromRepeatedKV(GroupByFieldKey, searchParamsPair)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -2,7 +2,6 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
"math"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -27,7 +26,6 @@ import (
|
|||||||
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
"github.com/milvus-io/milvus/pkg/util/commonpbutil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
"github.com/milvus-io/milvus/pkg/util/funcutil"
|
||||||
"github.com/milvus-io/milvus/pkg/util/merr"
|
"github.com/milvus-io/milvus/pkg/util/merr"
|
||||||
"github.com/milvus-io/milvus/pkg/util/metric"
|
|
||||||
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
"github.com/milvus-io/milvus/pkg/util/paramtable"
|
||||||
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
"github.com/milvus-io/milvus/pkg/util/timerecord"
|
||||||
"github.com/milvus-io/milvus/pkg/util/tsoutil"
|
"github.com/milvus-io/milvus/pkg/util/tsoutil"
|
||||||
@ -939,64 +937,6 @@ func selectHighestScoreIndex(subSearchResultData []*schemapb.SearchResultData, s
|
|||||||
return subSearchIdx, resultDataIdx
|
return subSearchIdx, resultDataIdx
|
||||||
}
|
}
|
||||||
|
|
||||||
type rangeSearchParams struct {
|
|
||||||
radius float64
|
|
||||||
rangeFilter float64
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkRangeSearchParams(str string, metricType string) error {
|
|
||||||
if len(str) == 0 {
|
|
||||||
// no search params, no need to check
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var data map[string]*json.RawMessage
|
|
||||||
err := json.Unmarshal([]byte(str), &data)
|
|
||||||
if err != nil {
|
|
||||||
log.Info("json Unmarshal fail when checkRangeSearchParams")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
radius, ok := data[radiusKey]
|
|
||||||
// will not do range search, no need to check
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if radius == nil {
|
|
||||||
return merr.WrapErrParameterInvalidMsg("pass invalid type for radius")
|
|
||||||
}
|
|
||||||
var params rangeSearchParams
|
|
||||||
err = json.Unmarshal(*radius, ¶ms.radius)
|
|
||||||
if err != nil {
|
|
||||||
return merr.WrapErrParameterInvalidMsg("must pass numpy type for radius")
|
|
||||||
}
|
|
||||||
|
|
||||||
rangeFilter, ok := data[rangeFilterKey]
|
|
||||||
// not pass range_filter, no need to check
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if rangeFilter == nil {
|
|
||||||
return merr.WrapErrParameterInvalidMsg("pass invalid type for range_filter")
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(*rangeFilter, ¶ms.rangeFilter)
|
|
||||||
if err != nil {
|
|
||||||
return merr.WrapErrParameterInvalidMsg("must pass numpy type for range_filter")
|
|
||||||
}
|
|
||||||
|
|
||||||
if metric.PositivelyRelated(metricType) {
|
|
||||||
if params.radius >= params.rangeFilter {
|
|
||||||
msg := fmt.Sprintf("metric type '%s', range_filter(%f) must be greater than radius(%f)", metricType, params.rangeFilter, params.radius)
|
|
||||||
return merr.WrapErrParameterInvalidMsg(msg)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if params.radius <= params.rangeFilter {
|
|
||||||
msg := fmt.Sprintf("metric type '%s', range_filter(%f) must be less than radius(%f)", metricType, params.rangeFilter, params.radius)
|
|
||||||
return merr.WrapErrParameterInvalidMsg(msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *searchTask) TraceCtx() context.Context {
|
func (t *searchTask) TraceCtx() context.Context {
|
||||||
return t.ctx
|
return t.ctx
|
||||||
}
|
}
|
||||||
|
|||||||
@ -132,56 +132,6 @@ func getBaseSearchParams() []*commonpb.KeyValuePair {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getBaseParamsForRangeSearchL2() []*commonpb.KeyValuePair {
|
|
||||||
return []*commonpb.KeyValuePair{
|
|
||||||
{
|
|
||||||
Key: AnnsFieldKey,
|
|
||||||
Value: testFloatVecField,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: TopKKey,
|
|
||||||
Value: "10",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: common.MetricTypeKey,
|
|
||||||
Value: metric.L2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: RoundDecimalKey,
|
|
||||||
Value: "-1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: IgnoreGrowingKey,
|
|
||||||
Value: "false",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getBaseParamsForRangeSearchIP() []*commonpb.KeyValuePair {
|
|
||||||
return []*commonpb.KeyValuePair{
|
|
||||||
{
|
|
||||||
Key: AnnsFieldKey,
|
|
||||||
Value: testFloatVecField,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: TopKKey,
|
|
||||||
Value: "10",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: common.MetricTypeKey,
|
|
||||||
Value: metric.IP,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: RoundDecimalKey,
|
|
||||||
Value: "-1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Key: IgnoreGrowingKey,
|
|
||||||
Value: "false",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getValidSearchParams() []*commonpb.KeyValuePair {
|
func getValidSearchParams() []*commonpb.KeyValuePair {
|
||||||
return []*commonpb.KeyValuePair{
|
return []*commonpb.KeyValuePair{
|
||||||
{
|
{
|
||||||
@ -2113,115 +2063,6 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
t.Run("check range search params", func(t *testing.T) {
|
|
||||||
normalParam := getValidSearchParams()
|
|
||||||
|
|
||||||
invalidTypeRadius := getBaseParamsForRangeSearchL2()
|
|
||||||
invalidTypeRadius = append(invalidTypeRadius, &commonpb.KeyValuePair{
|
|
||||||
Key: SearchParamsKey,
|
|
||||||
Value: `{"nprobe": 10, "radius": null}`,
|
|
||||||
})
|
|
||||||
|
|
||||||
invalidTypeFilter := getBaseParamsForRangeSearchL2()
|
|
||||||
invalidTypeFilter = append(invalidTypeFilter, &commonpb.KeyValuePair{
|
|
||||||
Key: SearchParamsKey,
|
|
||||||
Value: `{"nprobe": 10, "radius": 10, "range_filter": null}`,
|
|
||||||
})
|
|
||||||
|
|
||||||
normalParamWithNoFilter := getBaseParamsForRangeSearchL2()
|
|
||||||
normalParamWithNoFilter = append(normalParamWithNoFilter, &commonpb.KeyValuePair{
|
|
||||||
Key: SearchParamsKey,
|
|
||||||
Value: `{"nprobe": 10, "radius": 10}`,
|
|
||||||
})
|
|
||||||
|
|
||||||
normalParamForIP := getBaseParamsForRangeSearchIP()
|
|
||||||
normalParamForIP = append(normalParamForIP, &commonpb.KeyValuePair{
|
|
||||||
Key: SearchParamsKey,
|
|
||||||
Value: `{"nprobe": 10, "radius": 10, "range_filter": 20}`,
|
|
||||||
})
|
|
||||||
|
|
||||||
normalParamForL2 := getBaseParamsForRangeSearchL2()
|
|
||||||
normalParamForL2 = append(normalParamForL2, &commonpb.KeyValuePair{
|
|
||||||
Key: SearchParamsKey,
|
|
||||||
Value: `{"nprobe": 10, "radius": 20, "range_filter": 10}`,
|
|
||||||
})
|
|
||||||
|
|
||||||
abnormalParamForIP := getBaseParamsForRangeSearchIP()
|
|
||||||
abnormalParamForIP = append(abnormalParamForIP, &commonpb.KeyValuePair{
|
|
||||||
Key: SearchParamsKey,
|
|
||||||
Value: `{"nprobe": 10, "radius": 20, "range_filter": 10}`,
|
|
||||||
})
|
|
||||||
|
|
||||||
abnormalParamForL2 := getBaseParamsForRangeSearchL2()
|
|
||||||
abnormalParamForL2 = append(abnormalParamForL2, &commonpb.KeyValuePair{
|
|
||||||
Key: SearchParamsKey,
|
|
||||||
Value: `{"nprobe": 10, "radius": 10, "range_filter": 20}`,
|
|
||||||
})
|
|
||||||
|
|
||||||
wrongTypeRadius := getBaseParamsForRangeSearchIP()
|
|
||||||
wrongTypeRadius = append(wrongTypeRadius, &commonpb.KeyValuePair{
|
|
||||||
Key: SearchParamsKey,
|
|
||||||
Value: `{"nprobe": 10, "radius": "ab"}`,
|
|
||||||
})
|
|
||||||
|
|
||||||
wrongTypeFilter := getBaseParamsForRangeSearchIP()
|
|
||||||
wrongTypeFilter = append(wrongTypeFilter, &commonpb.KeyValuePair{
|
|
||||||
Key: SearchParamsKey,
|
|
||||||
Value: `{"nprobe": 10, "radius": 10, "range_filter": "20"}`,
|
|
||||||
})
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
description string
|
|
||||||
validParams []*commonpb.KeyValuePair
|
|
||||||
}{
|
|
||||||
{"normalParam", normalParam},
|
|
||||||
{"normalParamWithNoFilter", normalParamWithNoFilter},
|
|
||||||
{"normalParamForIP", normalParamForIP},
|
|
||||||
{"normalParamForL2", normalParamForL2},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.description, func(t *testing.T) {
|
|
||||||
info, _, err := parseSearchInfo(test.validParams, nil, false)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, info)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
tests = []struct {
|
|
||||||
description string
|
|
||||||
validParams []*commonpb.KeyValuePair
|
|
||||||
}{
|
|
||||||
{"abnormalParamForIP", abnormalParamForIP},
|
|
||||||
{"abnormalParamForL2", abnormalParamForL2},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.description, func(t *testing.T) {
|
|
||||||
info, _, err := parseSearchInfo(test.validParams, nil, false)
|
|
||||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
|
||||||
assert.Nil(t, info)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
tests = []struct {
|
|
||||||
description string
|
|
||||||
validParams []*commonpb.KeyValuePair
|
|
||||||
}{
|
|
||||||
{"invalidTypeRadius", invalidTypeRadius},
|
|
||||||
{"invalidTypeFilter", invalidTypeFilter},
|
|
||||||
{"wrongTypeRadius", wrongTypeRadius},
|
|
||||||
{"wrongTypeFilter", wrongTypeFilter},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
t.Run(test.description, func(t *testing.T) {
|
|
||||||
info, _, err := parseSearchInfo(test.validParams, nil, false)
|
|
||||||
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
|
|
||||||
assert.Nil(t, info)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
t.Run("check iterator and groupBy", func(t *testing.T) {
|
t.Run("check iterator and groupBy", func(t *testing.T) {
|
||||||
normalParam := getValidSearchParams()
|
normalParam := getValidSearchParams()
|
||||||
normalParam = append(normalParam, &commonpb.KeyValuePair{
|
normalParam = append(normalParam, &commonpb.KeyValuePair{
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user