fix: panic caused by type assert LocalSegment on Segment (#29018) (#29900)

- Make the implementations of LocalWorker and RemoteWorker consistent.

issue: #29017, #29899
pr: #29018

Signed-off-by: yah01 <yah2er0ne@outlook.com>
Co-authored-by: yah01 <yah2er0ne@outlook.com>
This commit is contained in:
chyezh 2024-01-12 10:08:50 +08:00 committed by GitHub
parent ef7e4aea43
commit f0db26107c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 61 deletions

View File

@ -18,18 +18,12 @@ package querynodev2
import (
"context"
"fmt"
"github.com/samber/lo"
"go.uber.org/zap"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/cluster"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/util/streamrpc"
"github.com/milvus-io/milvus/pkg/log"
"github.com/milvus-io/milvus/pkg/util/merr"
)
var _ cluster.Worker = &LocalWorker{}
@ -45,65 +39,18 @@ func NewLocalWorker(node *QueryNode) *LocalWorker {
}
func (w *LocalWorker) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) error {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.Int64s("segmentIDs", lo.Map(req.GetInfos(), func(info *querypb.SegmentLoadInfo, _ int) int64 {
return info.GetSegmentID()
})),
zap.String("loadScope", req.GetLoadScope().String()),
)
w.node.manager.Collection.PutOrRef(req.GetCollectionID(), req.GetSchema(),
w.node.composeIndexMeta(req.GetIndexInfoList(), req.GetSchema()), req.GetLoadMeta())
defer w.node.manager.Collection.Unref(req.GetCollectionID(), 1)
log.Info("start to load segments...")
loaded, err := w.node.loader.Load(ctx,
req.GetCollectionID(),
segments.SegmentTypeSealed,
req.GetVersion(),
req.GetInfos()...,
)
if err != nil {
return err
}
w.node.manager.Collection.Ref(req.GetCollectionID(), uint32(len(loaded)))
log.Info("load segments done...",
zap.Int64s("segments", lo.Map(loaded, func(s segments.Segment, _ int) int64 { return s.ID() })))
return err
status, err := w.node.LoadSegments(ctx, req)
return merr.CheckRPCCall(status, err)
}
func (w *LocalWorker) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) error {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", req.GetCollectionID()),
zap.Int64s("segmentIDs", req.GetSegmentIDs()),
zap.String("scope", req.GetScope().String()),
)
log.Info("start to release segments")
sealedCount := 0
for _, id := range req.GetSegmentIDs() {
_, count := w.node.manager.Segment.Remove(id, req.GetScope())
sealedCount += count
}
w.node.manager.Collection.Unref(req.GetCollectionID(), uint32(sealedCount))
return nil
status, err := w.node.ReleaseSegments(ctx, req)
return merr.CheckRPCCall(status, err)
}
func (w *LocalWorker) Delete(ctx context.Context, req *querypb.DeleteRequest) error {
log := log.Ctx(ctx).With(
zap.Int64("collectionID", req.GetCollectionId()),
zap.Int64("segmentID", req.GetSegmentId()),
)
log.Debug("start to process segment delete")
status, err := w.node.Delete(ctx, req)
if err != nil {
return err
}
if status.GetErrorCode() != commonpb.ErrorCode_Success {
return fmt.Errorf(status.GetReason())
}
return nil
return merr.CheckRPCCall(status, err)
}
func (w *LocalWorker) SearchSegments(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) {

View File

@ -24,6 +24,7 @@ import (
"github.com/stretchr/testify/suite"
clientv3 "go.etcd.io/etcd/client/v3"
"github.com/milvus-io/milvus-proto/go-api/v2/commonpb"
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
"github.com/milvus-io/milvus/internal/proto/indexpb"
"github.com/milvus-io/milvus/internal/proto/querypb"
@ -112,6 +113,9 @@ func (suite *LocalWorkerTestSuite) TestLoadSegment() {
// load empty
schema := segments.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64)
req := &querypb.LoadSegmentsRequest{
Base: &commonpb.MsgBase{
TargetID: suite.node.session.GetServerID(),
},
CollectionID: suite.collectionID,
Infos: lo.Map(suite.segmentIDs, func(segID int64, _ int) *querypb.SegmentLoadInfo {
return &querypb.SegmentLoadInfo{
@ -129,6 +133,9 @@ func (suite *LocalWorkerTestSuite) TestLoadSegment() {
func (suite *LocalWorkerTestSuite) TestReleaseSegment() {
req := &querypb.ReleaseSegmentsRequest{
Base: &commonpb.MsgBase{
TargetID: suite.node.session.GetServerID(),
},
CollectionID: suite.collectionID,
SegmentIDs: suite.segmentIDs,
}

View File

@ -94,9 +94,8 @@ func retrieveOnSegmentsWithStream(ctx context.Context, segments []Segment, segTy
wg.Add(1)
go func(segment Segment, i int) {
defer wg.Done()
seg := segment.(*LocalSegment)
tr := timerecord.NewTimeRecorder("retrieveOnSegmentsWithStream")
result, err := seg.Retrieve(ctx, plan)
result, err := segment.Retrieve(ctx, plan)
if err != nil {
errs[i] = err
return