diff --git a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp index 29bcd596c7..671ef9ae56 100644 --- a/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp +++ b/internal/core/src/query/visitors/ExecPlanNodeVisitor.cpp @@ -18,6 +18,7 @@ #include "query/generated/ExecExprVisitor.h" #include "segcore/SegmentGrowing.h" #include "utils/Json.h" +#include "log/Log.h" namespace milvus::query { diff --git a/internal/querynodev2/segments/retrieve.go b/internal/querynodev2/segments/retrieve.go index 12310f4f53..7c2932e251 100644 --- a/internal/querynodev2/segments/retrieve.go +++ b/internal/querynodev2/segments/retrieve.go @@ -18,31 +18,71 @@ package segments import ( "context" + "fmt" + "sync" + "github.com/milvus-io/milvus-proto/go-api/commonpb" "github.com/milvus-io/milvus/internal/proto/segcorepb" "github.com/milvus-io/milvus/internal/storage" + "github.com/milvus-io/milvus/pkg/metrics" + "github.com/milvus-io/milvus/pkg/util/paramtable" + "github.com/milvus-io/milvus/pkg/util/timerecord" . "github.com/milvus-io/milvus/pkg/util/typeutil" ) // retrieveOnSegments performs retrieve on listed segments // all segment ids are validated before calling this function func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentType, plan *RetrievePlan, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, error) { - var retrieveResults []*segcorepb.RetrieveResults + var ( + resultCh = make(chan *segcorepb.RetrieveResults, len(segIDs)) + errs = make([]error, len(segIDs)) + wg sync.WaitGroup + ) - for _, segID := range segIDs { - segment, _ := manager.Segment.Get(segID).(*LocalSegment) - if segment == nil { - continue - } - result, err := segment.Retrieve(ctx, plan) + label := metrics.SealedSegmentLabel + if segType == commonpb.SegmentState_Growing { + label = metrics.GrowingSegmentLabel + } + + for i, segID := range segIDs { + wg.Add(1) + go func(segID int64, i int) { + defer wg.Done() + segment, _ := manager.Segment.Get(segID).(*LocalSegment) + if segment == nil { + errs[i] = nil + return + } + tr := timerecord.NewTimeRecorder("retrieveOnSegments") + result, err := segment.Retrieve(ctx, plan) + if err != nil { + errs[i] = err + return + } + if err = segment.FillIndexedFieldsData(ctx, vcm, result); err != nil { + errs[i] = err + return + } + errs[i] = nil + resultCh <- result + metrics.QueryNodeSQSegmentLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), + metrics.QueryLabel, label).Observe(float64(tr.ElapseSpan().Milliseconds())) + }(segID, i) + } + wg.Wait() + close(resultCh) + + for _, err := range errs { if err != nil { return nil, err } - if err := segment.FillIndexedFieldsData(ctx, vcm, result); err != nil { - return nil, err - } + } + + var retrieveResults []*segcorepb.RetrieveResults + for result := range resultCh { retrieveResults = append(retrieveResults, result) } + return retrieveResults, nil } diff --git a/internal/querynodev2/segments/retrieve_test.go b/internal/querynodev2/segments/retrieve_test.go index a9fd44bb16..b357f4699f 100644 --- a/internal/querynodev2/segments/retrieve_test.go +++ b/internal/querynodev2/segments/retrieve_test.go @@ -145,6 +145,33 @@ func (suite *RetrieveSuite) TestRetrieveGrowing() { suite.Len(res[0].Offset, 3) } +func (suite *RetrieveSuite) TestRetrieveNonExistSegment() { + plan, err := genSimpleRetrievePlan(suite.collection) + suite.NoError(err) + + res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan, + suite.collectionID, + []int64{suite.partitionID}, + []int64{999}, + nil) + suite.NoError(err) + suite.Len(res, 0) +} + +func (suite *RetrieveSuite) TestRetrieveNilSegment() { + plan, err := genSimpleRetrievePlan(suite.collection) + suite.NoError(err) + + DeleteSegment(suite.sealed) + res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan, + suite.collectionID, + []int64{suite.partitionID}, + []int64{suite.sealed.ID()}, + nil) + suite.ErrorIs(err, ErrSegmentReleased) + suite.Len(res, 0) +} + func TestRetrieve(t *testing.T) { suite.Run(t, new(RetrieveSuite)) } diff --git a/internal/querynodev2/segments/segment.go b/internal/querynodev2/segments/segment.go index 5b88d845be..12114ddf19 100644 --- a/internal/querynodev2/segments/segment.go +++ b/internal/querynodev2/segments/segment.go @@ -396,6 +396,8 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco zap.Int64("collectionID", s.Collection()), zap.Int64("partitionID", s.Partition()), zap.Int64("segmentID", s.ID()), + zap.Int64("msgID", plan.msgID), + zap.String("segmentType", s.typ.String()), ) span := trace.SpanFromContext(ctx) @@ -421,14 +423,10 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco ) metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds())) + log.Debug("cgo retrieve done", zap.Duration("timeTaken", tr.ElapseSpan())) return nil, nil }).Await() - log.Debug("do retrieve on segment", - zap.Int64("msgID", plan.msgID), - zap.String("segmentType", s.typ.String()), - ) - if err := HandleCStatus(&status, "Retrieve failed"); err != nil { return nil, err } @@ -438,7 +436,7 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco return nil, err } - log.Debug("retrieve result", + log.Debug("retrieve segment done", zap.Int("resultNum", len(result.Offset)), )