diff --git a/internal/proxy/task.go b/internal/proxy/task.go
index d1eb2bd186..934c049e21 100644
--- a/internal/proxy/task.go
+++ b/internal/proxy/task.go
@@ -1568,7 +1568,7 @@ func reduceSearchResultsParallel(hits [][]*milvuspb.Hits, nq, availableQueryNode
 	return ret
 }
 
-func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string, maxParallel int) *milvuspb.SearchResults {
+func reduceSearchResultDataParallel(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string, maxParallel int) (*milvuspb.SearchResults, error) {
 	log.Debug("reduceSearchResultDataParallel", zap.Any("NumOfGoRoutines", maxParallel))
 
 	ret := &milvuspb.SearchResults{
@@ -1593,10 +1593,12 @@ func reduceSearchResultDataParallel(searchResultDat
 
 	const minFloat32 = -1 * float32(math.MaxFloat32)
 	// TODO(yukun): Use parallel function
+	realTopK := -1
 	for idx := 0; idx < nq; idx++ {
 		locs := make([]int, availableQueryNodeNum)
 
-		for j := 0; j < topk; j++ {
+		j := 0
+		for ; j < topk; j++ {
 			valid := false
 			choice, maxDistance := 0, minFloat32
 			for q, loc := range locs { // query num, the number of ways to merge
@@ -1696,7 +1698,7 @@ func reduceSearchResultDataParallel(searchResultDat
 				}
 			default:
 				log.Debug("Not supported field type")
-				return nil
+				return nil, fmt.Errorf("not supported field type: %s", fieldData.Type.String())
 			}
 		case *schemapb.FieldData_Vectors:
 			dim := fieldType.Vectors.Dim
@@ -1729,9 +1731,15 @@ func reduceSearchResultDataParallel(searchResultDat
 			ret.Results.Scores = append(ret.Results.Scores, searchResultData[choice].Scores[idx*topk+choiceOffset])
 			locs[choice]++
 		}
-
+		if realTopK != -1 && realTopK != j {
+			log.Warn("Proxy Reduce Search Result", zap.Error(errors.New("the length (topk) between all result of query is different")))
+			// return nil, errors.New("the length (topk) between all result of query is different")
+		}
+		realTopK = j
 	}
 
+	ret.Results.TopK = int64(realTopK)
+
 	if metricType != "IP" {
 		for k := range ret.Results.Scores {
 			ret.Results.Scores[k] *= -1
@@ -1742,7 +1750,7 @@ func reduceSearchResultDataParallel(searchResultDat
 	// 	return nil
 	// }
 
-	return ret
+	return ret, nil
 }
 
 func reduceSearchResultsSerial(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
@@ -1767,7 +1775,7 @@ func reduceSearchResults(hits [][]*milvuspb.Hits, nq, availableQueryNodeNum, top
 	return reduceSearchResultsParallelByCPU(hits, nq, availableQueryNodeNum, topk, metricType)
 }
 
-func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string) *milvuspb.SearchResults {
+func reduceSearchResultData(searchResultData []*schemapb.SearchResultData, nq, availableQueryNodeNum, topk int, metricType string) (*milvuspb.SearchResults, error) {
 	t := time.Now()
 	defer func() {
 		log.Debug("reduceSearchResults", zap.Any("time cost", time.Since(t)))
@@ -1853,7 +1861,10 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
 	}
 
 	nq := results[0].NumQueries
-	topk := results[0].TopK
+	topk := 0
+	for _, partialResult := range results {
+		topk = getMax(topk, int(partialResult.TopK))
+	}
 	if nq <= 0 {
 		st.result = &milvuspb.SearchResults{
 			Status: &commonpb.Status{
@@ -1864,7 +1875,10 @@ func (st *SearchTask) PostExecute(ctx context.Context) error {
 		return nil
 	}
 
-	st.result = reduceSearchResultData(results, int(nq), availableQueryNodeNum, int(topk), searchResults[0].MetricType)
+	st.result, err = reduceSearchResultData(results, int(nq), availableQueryNodeNum, topk, searchResults[0].MetricType)
+	if err != nil {
+		return err
+	}
 	schema, err := globalMetaCache.GetCollectionSchema(ctx, st.query.CollectionName)
 	if err != nil {
 		return err
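
Note: the PostExecute hunk above calls a getMax helper that is not shown in this excerpt. Assuming it is simply an integer max used to pick the largest per-node TopK as the merge width, a minimal hypothetical sketch (not code from this patch) would be:

func getMax(a, b int) int {
	// Return the larger of two ints; PostExecute would use this to take the
	// largest TopK reported by any query node before calling reduceSearchResultData.
	if a > b {
		return a
	}
	return b
}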