diff --git a/.gitignore b/.gitignore index 59a1dbe579..f17a9b5aba 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ **/cmake-build-release/* **/cmake_build_release/* **/cmake_build/* +.cache internal/core/output/* internal/core/build/* @@ -87,4 +88,4 @@ deployments/docker/*/volumes # rocksdb cwrapper_rocksdb_build/ -internal/kv/rocksdb/cwrapper/ \ No newline at end of file +internal/kv/rocksdb/cwrapper/ diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 7d4ed4cbd8..7e671d52a8 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -51,17 +51,19 @@ import ( ) const ( + AnnsFieldKey = "anns_field" + TopKKey = "topk" + MetricTypeKey = "metric_type" + SearchParamsKey = "params" + RoundDecimalKey = "round_decimal" + OffsetKey = "offset" + InsertTaskName = "InsertTask" CreateCollectionTaskName = "CreateCollectionTask" DropCollectionTaskName = "DropCollectionTask" SearchTaskName = "SearchTask" RetrieveTaskName = "RetrieveTask" QueryTaskName = "QueryTask" - AnnsFieldKey = "anns_field" - TopKKey = "topk" - MetricTypeKey = "metric_type" - SearchParamsKey = "params" - RoundDecimalKey = "round_decimal" HasCollectionTaskName = "HasCollectionTask" DescribeCollectionTaskName = "DescribeCollectionTask" GetCollectionStatisticsTaskName = "GetCollectionStatisticsTask" diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 7a0eeee3f5..24e59d0dd1 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -93,10 +93,13 @@ func parseQueryInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryInf if err != nil { return nil, errors.New(TopKKey + " not found in search_params") } - topK, err := strconv.Atoi(topKStr) + topK, err := strconv.ParseInt(topKStr, 0, 64) if err != nil { return nil, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr) } + if err := validateTopK(topK); err != nil { + return nil, fmt.Errorf("invalid limit, %w", err) + } metricType, err := funcutil.GetAttrByKeyFromRepeatedKV(MetricTypeKey, searchParamsPair) if err != nil { @@ -112,7 +115,7 @@ func parseQueryInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryInf if err != nil { roundDecimalStr = "-1" } - roundDecimal, err := strconv.Atoi(roundDecimalStr) + roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64) if err != nil { return nil, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr) } @@ -122,10 +125,10 @@ func parseQueryInfo(searchParamsPair []*commonpb.KeyValuePair) (*planpb.QueryInf } return &planpb.QueryInfo{ - Topk: int64(topK), + Topk: topK, MetricType: metricType, SearchParams: searchParams, - RoundDecimal: int64(roundDecimal), + RoundDecimal: roundDecimal, }, nil } @@ -242,6 +245,7 @@ func (t *searchTask) PreExecute(ctx context.Context) error { t.SearchRequest.OutputFieldsId = outputFieldIDs plan.OutputFieldIds = outputFieldIDs + t.SearchRequest.Topk = queryInfo.GetTopk() t.SearchRequest.MetricType = queryInfo.GetMetricType() t.SearchRequest.DslType = commonpb.DslType_BoolExprV1 t.SearchRequest.SerializedExprPlan, err = proto.Marshal(plan) @@ -249,10 +253,6 @@ func (t *searchTask) PreExecute(ctx context.Context) error { return err } - t.SearchRequest.Topk = queryInfo.GetTopk() - if err := validateTopK(queryInfo.GetTopk()); err != nil { - return err - } log.Ctx(ctx).Debug("Proxy::searchTask::PreExecute", zap.Int64("msgID", t.ID()), zap.Int64s("plan.OutputFieldIds", plan.GetOutputFieldIds()), zap.String("plan", plan.String())) // may be very large if large term passed. @@ -647,18 +647,6 @@ func reduceSearchResultData(ctx context.Context, searchResultData []*schemapb.Se // } //} -// func printSearchResult(partialSearchResult *internalpb.SearchResults) { -// for i := 0; i < len(partialSearchResult.Hits); i++ { -// testHits := milvuspb.Hits{} -// err := proto.Unmarshal(partialSearchResult.Hits[i], &testHits) -// if err != nil { -// panic(err) -// } -// fmt.Println(testHits.IDs) -// fmt.Println(testHits.Scores) -// } -// } - func (t *searchTask) TraceCtx() context.Context { return t.ctx } diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 460ba33f8a..42394942b9 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -1697,6 +1697,11 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) { Value: "invalid", }) + spInvalidTopk65536 := append(spNoTopk, &commonpb.KeyValuePair{ + Key: TopKKey, + Value: "65536", + }) + spNoMetricType := append(spNoTopk, &commonpb.KeyValuePair{ Key: TopKKey, Value: "10", @@ -1727,6 +1732,7 @@ func TestTaskSearch_parseQueryInfo(t *testing.T) { }{ {"No_topk", spNoTopk}, {"Invalid_topk", spInvalidTopk}, + {"Invalid_topk_65536", spInvalidTopk65536}, {"No_Metric_type", spNoMetricType}, {"No_search_params", spNoSearchParams}, {"Invalid_round_decimal", spInvalidRoundDecimal},