diff --git a/internal/querynodev2/local_worker.go b/internal/querynodev2/local_worker.go index 790ac15a55..cc05af99dd 100644 --- a/internal/querynodev2/local_worker.go +++ b/internal/querynodev2/local_worker.go @@ -18,18 +18,12 @@ package querynodev2 import ( "context" - "fmt" - "github.com/samber/lo" - "go.uber.org/zap" - - "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus/internal/proto/internalpb" "github.com/milvus-io/milvus/internal/proto/querypb" "github.com/milvus-io/milvus/internal/querynodev2/cluster" - "github.com/milvus-io/milvus/internal/querynodev2/segments" "github.com/milvus-io/milvus/internal/util/streamrpc" - "github.com/milvus-io/milvus/pkg/log" + "github.com/milvus-io/milvus/pkg/util/merr" ) var _ cluster.Worker = &LocalWorker{} @@ -45,65 +39,18 @@ func NewLocalWorker(node *QueryNode) *LocalWorker { } func (w *LocalWorker) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error { - log := log.Ctx(ctx).With( - zap.Int64("collectionID", req.GetCollectionID()), - zap.Int64s("segmentIDs", lo.Map(req.GetInfos(), func(info *querypb.SegmentLoadInfo, _ int) int64 { - return info.GetSegmentID() - })), - zap.String("loadScope", req.GetLoadScope().String()), - ) - w.node.manager.Collection.PutOrRef(req.GetCollectionID(), req.GetSchema(), - w.node.composeIndexMeta(req.GetIndexInfoList(), req.GetSchema()), req.GetLoadMeta()) - defer w.node.manager.Collection.Unref(req.GetCollectionID(), 1) - log.Info("start to load segments...") - loaded, err := w.node.loader.Load(ctx, - req.GetCollectionID(), - segments.SegmentTypeSealed, - req.GetVersion(), - req.GetInfos()..., - ) - if err != nil { - return err - } - - w.node.manager.Collection.Ref(req.GetCollectionID(), uint32(len(loaded))) - - log.Info("load segments done...", - zap.Int64s("segments", lo.Map(loaded, func(s segments.Segment, _ int) int64 { return s.ID() }))) - return err + status, err := w.node.LoadSegments(ctx, req) + return merr.CheckRPCCall(status, err) } func (w *LocalWorker) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) error { - log := log.Ctx(ctx).With( - zap.Int64("collectionID", req.GetCollectionID()), - zap.Int64s("segmentIDs", req.GetSegmentIDs()), - zap.String("scope", req.GetScope().String()), - ) - log.Info("start to release segments") - sealedCount := 0 - for _, id := range req.GetSegmentIDs() { - _, count := w.node.manager.Segment.Remove(id, req.GetScope()) - sealedCount += count - } - w.node.manager.Collection.Unref(req.GetCollectionID(), uint32(sealedCount)) - - return nil + status, err := w.node.ReleaseSegments(ctx, req) + return merr.CheckRPCCall(status, err) } func (w *LocalWorker) Delete(ctx context.Context, req *querypb.DeleteRequest) error { - log := log.Ctx(ctx).With( - zap.Int64("collectionID", req.GetCollectionId()), - zap.Int64("segmentID", req.GetSegmentId()), - ) - log.Debug("start to process segment delete") status, err := w.node.Delete(ctx, req) - if err != nil { - return err - } - if status.GetErrorCode() != commonpb.ErrorCode_Success { - return fmt.Errorf(status.GetReason()) - } - return nil + return merr.CheckRPCCall(status, err) } func (w *LocalWorker) SearchSegments(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) { diff --git a/internal/querynodev2/local_worker_test.go b/internal/querynodev2/local_worker_test.go index 65a4164841..b34becbf8a 100644 --- a/internal/querynodev2/local_worker_test.go +++ b/internal/querynodev2/local_worker_test.go @@ -24,6 +24,7 @@ import ( "github.com/stretchr/testify/suite" clientv3 "go.etcd.io/etcd/client/v3" + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" "github.com/milvus-io/milvus-proto/go-api/v2/schemapb" "github.com/milvus-io/milvus/internal/proto/indexpb" "github.com/milvus-io/milvus/internal/proto/querypb" @@ -112,6 +113,9 @@ func (suite *LocalWorkerTestSuite) TestLoadSegment() { // load empty schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64) req := &querypb.LoadSegmentsRequest{ + Base: &commonpb.MsgBase{ + TargetID: suite.node.session.GetServerID(), + }, CollectionID: suite.collectionID, Infos: lo.Map(suite.segmentIDs, func(segID int64, _ int) *querypb.SegmentLoadInfo { return &querypb.SegmentLoadInfo{ @@ -129,6 +133,9 @@ func (suite *LocalWorkerTestSuite) TestLoadSegment() { func (suite *LocalWorkerTestSuite) TestReleaseSegment() { req := &querypb.ReleaseSegmentsRequest{ + Base: &commonpb.MsgBase{ + TargetID: suite.node.session.GetServerID(), + }, CollectionID: suite.collectionID, SegmentIDs: suite.segmentIDs, } diff --git a/internal/querynodev2/segments/retrieve.go b/internal/querynodev2/segments/retrieve.go index 26932ed22f..bfff4fd021 100644 --- a/internal/querynodev2/segments/retrieve.go +++ b/internal/querynodev2/segments/retrieve.go @@ -94,9 +94,8 @@ func retrieveOnSegmentsWithStream(ctx context.Context, segments []Segment, segTy wg.Add(1) go func(segment Segment, i int) { defer wg.Done() - seg := segment.(*LocalSegment) tr := timerecord.NewTimeRecorder("retrieveOnSegmentsWithStream") - result, err := seg.Retrieve(ctx, plan) + result, err := segment.Retrieve(ctx, plan) if err != nil { errs[i] = err return