Retrieve segments concurrently (#24245)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
This commit is contained in:
yihao.dai 2023-05-19 18:19:24 +08:00 committed by GitHub
parent 9729b6079d
commit 45aa9779e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 82 additions and 16 deletions

View File

@ -18,6 +18,7 @@
#include "query/generated/ExecExprVisitor.h"
#include "segcore/SegmentGrowing.h"
#include "utils/Json.h"
#include "log/Log.h"
namespace milvus::query {

View File

@ -18,31 +18,71 @@ package segments
import (
"context"
"fmt"
"sync"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus/internal/proto/segcorepb"
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/pkg/metrics"
"github.com/milvus-io/milvus/pkg/util/paramtable"
"github.com/milvus-io/milvus/pkg/util/timerecord"
. "github.com/milvus-io/milvus/pkg/util/typeutil"
)
// retrieveOnSegments performs retrieve on listed segments
// all segment ids are validated before calling this function
func retrieveOnSegments(ctx context.Context, manager *Manager, segType SegmentType, plan *RetrievePlan, segIDs []UniqueID, vcm storage.ChunkManager) ([]*segcorepb.RetrieveResults, error) {
var retrieveResults []*segcorepb.RetrieveResults
var (
resultCh = make(chan *segcorepb.RetrieveResults, len(segIDs))
errs = make([]error, len(segIDs))
wg sync.WaitGroup
)
for _, segID := range segIDs {
segment, _ := manager.Segment.Get(segID).(*LocalSegment)
if segment == nil {
continue
}
result, err := segment.Retrieve(ctx, plan)
label := metrics.SealedSegmentLabel
if segType == commonpb.SegmentState_Growing {
label = metrics.GrowingSegmentLabel
}
for i, segID := range segIDs {
wg.Add(1)
go func(segID int64, i int) {
defer wg.Done()
segment, _ := manager.Segment.Get(segID).(*LocalSegment)
if segment == nil {
errs[i] = nil
return
}
tr := timerecord.NewTimeRecorder("retrieveOnSegments")
result, err := segment.Retrieve(ctx, plan)
if err != nil {
errs[i] = err
return
}
if err = segment.FillIndexedFieldsData(ctx, vcm, result); err != nil {
errs[i] = err
return
}
errs[i] = nil
resultCh <- result
metrics.QueryNodeSQSegmentLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryLabel, label).Observe(float64(tr.ElapseSpan().Milliseconds()))
}(segID, i)
}
wg.Wait()
close(resultCh)
for _, err := range errs {
if err != nil {
return nil, err
}
if err := segment.FillIndexedFieldsData(ctx, vcm, result); err != nil {
return nil, err
}
}
var retrieveResults []*segcorepb.RetrieveResults
for result := range resultCh {
retrieveResults = append(retrieveResults, result)
}
return retrieveResults, nil
}

View File

@ -145,6 +145,33 @@ func (suite *RetrieveSuite) TestRetrieveGrowing() {
suite.Len(res[0].Offset, 3)
}
func (suite *RetrieveSuite) TestRetrieveNonExistSegment() {
plan, err := genSimpleRetrievePlan(suite.collection)
suite.NoError(err)
res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan,
suite.collectionID,
[]int64{suite.partitionID},
[]int64{999},
nil)
suite.NoError(err)
suite.Len(res, 0)
}
func (suite *RetrieveSuite) TestRetrieveNilSegment() {
plan, err := genSimpleRetrievePlan(suite.collection)
suite.NoError(err)
DeleteSegment(suite.sealed)
res, _, _, err := RetrieveHistorical(context.TODO(), suite.manager, plan,
suite.collectionID,
[]int64{suite.partitionID},
[]int64{suite.sealed.ID()},
nil)
suite.ErrorIs(err, ErrSegmentReleased)
suite.Len(res, 0)
}
func TestRetrieve(t *testing.T) {
suite.Run(t, new(RetrieveSuite))
}

View File

@ -396,6 +396,8 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco
zap.Int64("collectionID", s.Collection()),
zap.Int64("partitionID", s.Partition()),
zap.Int64("segmentID", s.ID()),
zap.Int64("msgID", plan.msgID),
zap.String("segmentType", s.typ.String()),
)
span := trace.SpanFromContext(ctx)
@ -421,14 +423,10 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco
)
metrics.QueryNodeSQSegmentLatencyInCore.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()),
metrics.QueryLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("cgo retrieve done", zap.Duration("timeTaken", tr.ElapseSpan()))
return nil, nil
}).Await()
log.Debug("do retrieve on segment",
zap.Int64("msgID", plan.msgID),
zap.String("segmentType", s.typ.String()),
)
if err := HandleCStatus(&status, "Retrieve failed"); err != nil {
return nil, err
}
@ -438,7 +436,7 @@ func (s *LocalSegment) Retrieve(ctx context.Context, plan *RetrievePlan) (*segco
return nil, err
}
log.Debug("retrieve result",
log.Debug("retrieve segment done",
zap.Int("resultNum", len(result.Offset)),
)