From 9d0aa5c20295bb36051ae2928ade2622ec6d52a8 Mon Sep 17 00:00:00 2001
From: Chun Han <116052805+MrPresent-Han@users.noreply.github.com>
Date: Tue, 10 Sep 2024 14:25:07 +0800
Subject: [PATCH] fix: empty result when having only one subReq(#36098) (#36128)

related: #36098

Signed-off-by: MrPresent-Han
Co-authored-by: MrPresent-Han
---
 internal/querynodev2/segments/result.go       |   4 -
 .../hybridsearch/hybridsearch_test.go         | 173 ++++++++++++++++++
 2 files changed, 173 insertions(+), 4 deletions(-)

diff --git a/internal/querynodev2/segments/result.go b/internal/querynodev2/segments/result.go
index 6549835f8b..70c3b28a22 100644
--- a/internal/querynodev2/segments/result.go
+++ b/internal/querynodev2/segments/result.go
@@ -127,10 +127,6 @@ func ReduceAdvancedSearchResults(ctx context.Context, results []*internalpb.Sear
 	_, sp := otel.Tracer(typeutil.QueryNodeRole).Start(ctx, "ReduceAdvancedSearchResults")
 	defer sp.End()
 
-	if len(results) == 1 {
-		return results[0], nil
-	}
-
 	channelsMvcc := make(map[string]uint64)
 	relatedDataSize := int64(0)
 	searchResults := &internalpb.SearchResults{
diff --git a/tests/integration/hybridsearch/hybridsearch_test.go b/tests/integration/hybridsearch/hybridsearch_test.go
index a0fa638c5a..cc6fd667d0 100644
--- a/tests/integration/hybridsearch/hybridsearch_test.go
+++ b/tests/integration/hybridsearch/hybridsearch_test.go
@@ -249,6 +249,179 @@ func (s *HybridSearchSuite) TestHybridSearch() {
 	log.Info("TestHybridSearch succeed")
 }
 
+// TestHybridSearchSingleSubReq is a regression test for #36098: a hybrid search
+// that carries exactly one sub-request must still return a full, non-empty
+// result set after reduction. The same single float-vector sub-request is sent
+// once with the RRF ranker and once with the weighted ranker.
+func (s *HybridSearchSuite) TestHybridSearchSingleSubReq() {
+	c := s.Cluster
+	ctx, cancel := context.WithCancel(c.GetContext())
+	defer cancel()
+
+	prefix := "TestHybridSearch"
+	dbName := ""
+	collectionName := prefix + funcutil.GenRandomStr()
+	dim := 128
+	rowNum := 3000
+
+	schema := integration.ConstructSchema(collectionName, dim, true,
+		&schemapb.FieldSchema{Name: integration.Int64Field, DataType: schemapb.DataType_Int64, IsPrimaryKey: true, AutoID: true},
+		&schemapb.FieldSchema{Name: integration.FloatVecField, DataType: schemapb.DataType_FloatVector, TypeParams: []*commonpb.KeyValuePair{{Key: common.DimKey, Value: strconv.Itoa(dim)}}},
+	)
+	marshaledSchema, err := proto.Marshal(schema)
+	s.NoError(err)
+
+	createCollectionStatus, err := c.Proxy.CreateCollection(ctx, &milvuspb.CreateCollectionRequest{
+		DbName:         dbName,
+		CollectionName: collectionName,
+		Schema:         marshaledSchema,
+		ShardsNum:      common.DefaultShardsNum,
+	})
+	s.NoError(err)
+
+	err = merr.Error(createCollectionStatus)
+	if err != nil {
+		log.Warn("createCollectionStatus fail reason", zap.Error(err))
+	}
+
+	log.Info("CreateCollection result", zap.Any("createCollectionStatus", createCollectionStatus))
+	showCollectionsResp, err := c.Proxy.ShowCollections(ctx, &milvuspb.ShowCollectionsRequest{})
+	s.NoError(err)
+	s.True(merr.Ok(showCollectionsResp.GetStatus()))
+	log.Info("ShowCollections result", zap.Any("showCollectionsResp", showCollectionsResp))
+
+	fVecColumn := integration.NewFloatVectorFieldData(integration.FloatVecField, rowNum, dim)
+	hashKeys := integration.GenerateHashKeys(rowNum)
+	insertResult, err := c.Proxy.Insert(ctx, &milvuspb.InsertRequest{
+		DbName:         dbName,
+		CollectionName: collectionName,
+		FieldsData:     []*schemapb.FieldData{fVecColumn},
+		HashKeys:       hashKeys,
+		NumRows:        uint32(rowNum),
+	})
+	s.NoError(err)
+	s.True(merr.Ok(insertResult.GetStatus()))
+
+	// flush
+	flushResp, err := c.Proxy.Flush(ctx, &milvuspb.FlushRequest{
+		DbName:          dbName,
+		CollectionNames: []string{collectionName},
+	})
+	s.NoError(err)
+	segmentIDs, has := flushResp.GetCollSegIDs()[collectionName]
+	s.Require().True(has)
+	s.Require().NotEmpty(segmentIDs)
+	ids := segmentIDs.GetData()
+	flushTs, has := flushResp.GetCollFlushTs()[collectionName]
+	s.True(has)
+
+	segments, err := c.MetaWatcher.ShowSegments()
+	s.NoError(err)
+	s.NotEmpty(segments)
+	for _, segment := range segments {
+		log.Info("ShowSegments result", zap.String("segment", segment.String()))
+	}
+	s.WaitForFlush(ctx, ids, flushTs, dbName, collectionName)
+
+	// load without index on vector fields; the returned status must carry an error
+	loadStatus, err := c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
+		DbName:         dbName,
+		CollectionName: collectionName,
+	})
+	s.NoError(err)
+	s.Error(merr.Error(loadStatus))
+
+	// create index for float vector
+	createIndexStatus, err := c.Proxy.CreateIndex(ctx, &milvuspb.CreateIndexRequest{
+		CollectionName: collectionName,
+		FieldName:      integration.FloatVecField,
+		IndexName:      "_default_float",
+		ExtraParams:    integration.ConstructIndexParam(dim, integration.IndexFaissIvfFlat, metric.L2),
+	})
+	s.NoError(err)
+	err = merr.Error(createIndexStatus)
+	if err != nil {
+		log.Warn("createIndexStatus fail reason", zap.Error(err))
+	}
+	s.WaitForIndexBuilt(ctx, collectionName, integration.FloatVecField)
+
+	// load with index on vector fields
+	loadStatus, err = c.Proxy.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
+		DbName:         dbName,
+		CollectionName: collectionName,
+	})
+	s.NoError(err)
+	s.NoError(merr.Error(loadStatus))
+
+	// search
+	expr := fmt.Sprintf("%s > 0", integration.Int64Field)
+	nq := 1
+	topk := 10
+	roundDecimal := -1
+
+	fParams := integration.GetSearchParams(integration.IndexFaissIvfFlat, metric.L2)
+	fSearchReq := integration.ConstructSearchRequest("", collectionName, expr,
+		integration.FloatVecField, schemapb.DataType_FloatVector, nil, metric.L2, fParams, nq, dim, topk, roundDecimal)
+
+	hSearchReq := &milvuspb.HybridSearchRequest{
+		Base:           nil,
+		DbName:         dbName,
+		CollectionName: collectionName,
+		PartitionNames: nil,
+		Requests:       []*milvuspb.SearchRequest{fSearchReq},
+		OutputFields:   []string{integration.FloatVecField},
+	}
+
+	// rrf rank hybrid search
+	rrfParams := make(map[string]float64)
+	rrfParams[proxy.RRFParamsKey] = 60
+	b, err := json.Marshal(rrfParams)
+	s.NoError(err)
+	hSearchReq.RankParams = []*commonpb.KeyValuePair{
+		{Key: proxy.RankTypeKey, Value: "rrf"},
+		{Key: proxy.RankParamsKey, Value: string(b)},
+		{Key: proxy.LimitKey, Value: strconv.Itoa(topk)},
+		{Key: proxy.RoundDecimalKey, Value: strconv.Itoa(roundDecimal)},
+	}
+
+	searchResult, err := c.Proxy.HybridSearch(ctx, hSearchReq)
+
+	s.NoError(merr.CheckRPCCall(searchResult, err))
+	// a successful RPC is not enough: the single sub-request must yield topk hits
+	s.Equal(topk, len(searchResult.GetResults().GetIds().GetIntId().Data))
+	s.Equal(topk, len(searchResult.GetResults().GetScores()))
+	s.Equal(int64(nq), searchResult.GetResults().GetNumQueries())
+
+	// weighted rank hybrid search
+	weightsParams := make(map[string][]float64)
+	weightsParams[proxy.WeightsParamsKey] = []float64{0.5}
+	b, err = json.Marshal(weightsParams)
+	s.NoError(err)
+
+	// create a new request preventing data race
+	hSearchReq = &milvuspb.HybridSearchRequest{
+		Base:           nil,
+		DbName:         dbName,
+		CollectionName: collectionName,
+		PartitionNames: nil,
+		Requests:       []*milvuspb.SearchRequest{fSearchReq},
+		OutputFields:   []string{integration.FloatVecField},
+	}
+	hSearchReq.RankParams = []*commonpb.KeyValuePair{
+		{Key: proxy.RankTypeKey, Value: "weighted"},
+		{Key: proxy.RankParamsKey, Value: string(b)},
+		{Key: proxy.LimitKey, Value: strconv.Itoa(topk)},
+	}
+
+	searchResult, err = c.Proxy.HybridSearch(ctx, hSearchReq)
+
+	s.NoError(merr.CheckRPCCall(searchResult, err))
+	s.Equal(topk, len(searchResult.GetResults().GetIds().GetIntId().Data))
+	s.Equal(topk, len(searchResult.GetResults().GetScores()))
+	s.Equal(int64(nq), searchResult.GetResults().GetNumQueries())
+	log.Info("TestHybridSearchSingleSubReq succeed")
+}
+
 func TestHybridSearch(t *testing.T) {
 	suite.Run(t, new(HybridSearchSuite))
 }